sign-predictor/analyze_model.ipynb
2023-03-05 16:34:38 +00:00

121 lines
20 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from src.keypoint_extractor import KeypointExtractor\n",
"from src.datasets.finger_spelling_dataset import FingerSpellingDataset\n",
"from src.identifiers import LANDMARKS\n",
"import torch\n",
"from src.model import SPOTER\n",
"from sklearn.metrics import confusion_matrix\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sn\n",
"import matplotlib.pyplot as plt\n",
"from src.augmentations import MirrorKeypoints\n",
"\n",
"keypoints_extractor = KeypointExtractor(\"data/fingerspelling/data/\")\n",
"test_set = FingerSpellingDataset(\"data/fingerspelling/data/\", keypoints_extractor, keypoints_identifier=LANDMARKS, subset=\"val\", transform=MirrorKeypoints())\n",
"\n",
"spoter = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)\n",
"spoter.load_state_dict(torch.load('models/spoter_40.pth'))\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot: >"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA28AAAJMCAYAAABtgJ7QAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAydUlEQVR4nO3de5xVZb0/8O+GgS0ijFyGi9fsaBAqlEhi6EnICwR4yWN5fppA5Uklb2Qlngov2Zial5IoFUFPmWRejpfEo3CU/CmKkOBdsbwitxRGvGwZ9v790U9yAmX2sGfWrMX73Wv9sdfstfZ3fD3t8etnPc+TK5VKpQAAAKBVa5N0AQAAAGya5g0AACAFNG8AAAApoHkDAABIAc0bAABACmjeAAAAUkDzBgAAkAKaNwAAgBTQvAEAAKSA5g0AACAFNG8AAADN7BOf+ETkcrkNjvHjxzf6HlXNWB8AAAARMW/evFi3bt3610888UQcdNBBcdRRRzX6HrlSqVRqjuIAAADYuNNOOy3uuOOOeP755yOXyzXqGskbAABAExQKhSgUCg3O5fP5yOfzH3vd+++/H7/5zW9iwoQJjW7cIlpR87bm+19OugQyZttLH066BACA1Kp//7WkS2iStSv/0mKfVXvFdXHOOec0ODdp0qQ4++yzP/a6W2+9NVatWhVjx44t6/NazWOTmjcqTfMGANB0mrdNK3bavknJ2yGHHBLt27eP22+/vazPazXJGwAAwGYrrtv0eyqkMY3aP3vppZfi3nvvjZtvvrnsz7NVAAAAQAuZNm1a9OjRI0aOHFn2tZI3AAAgO0rFpCv4SMViMaZNmxZjxoyJqqryWzHJGwAAQAu499574+WXX46vf/3rTbpe8gYAANACDj744Nic9SI1bwAAQHYUW+9jk5vLY5MAAAApIHkDAAAyo9SKFyzZXJI3AACAFJC8AQAA2WHOGwAAAEmSvAEAANlhzhsAAABJkrwBAADZUVyXdAXNRvIGAACQApI3AAAgO8x5AwAAIEmSNwAAIDvs8wYAAECSJG8AAEBmlMx5AwAAIEmSNwAAIDvMeQMAACBJmjcAAIAU8NgkAACQHRYsAQAAIEmSNwAAIDuK65KuoNlI3gAAAFJA8gYAAGSHOW8AAAAkSfIGAABkh026AQAASJLkDQAAyA5z3gAAAEiS5A0AAMgOc94AAABIkuQNAADIjFJpXdIlNBvJGwAAQApI3gAAgOyw2iQAAABJkrwBAADZYbVJAAAAkiR5AwAAssOcNwAAAJKkeQMAAEgBj00CAADZUbRJNwlqf+BXY5uf3tzg2Po7P0+6LFLuxBPGxOLn5saauhfiwQduj0F7fybpkkg5Y4pKM6aoNGOKtNO8pcS6pS/H2+d9ff3xzpT/TLokUuyoow6Niy+aFOf9+JIYtM/wWLjoqfjjnb+NmppuSZdGShlTVJoxRaUZU1uQUrHljhZW0ebtiSeeqOTt+LDiuiitWbX+iHfeSroiUuz0U4+Pq6deH9de9/t4+unn46TxZ8Y777wb48YenXRppJQxRaUZU1SaMUUWbHbz9tZbb8WVV14Zn/vc52LAgAGVqImNaNO9d2z9n1fH1t/7ZeSPPi1y23ZPuiRSql27drHXXv1j1uw/rT9XKpVi1uwHYvDggQlWRloZU1SaMUWlGVNbmGKx5Y4W1uTmbc6cOTFmzJjo3bt3XHzxxTFs2LCYO3duJWvj/1v3ynPx3u9/Ee9NPS8Kt14Zbbr2iA4nnB/RfqukSyOFunfvGlVVVbF82coG55cvXxG9etYkVBVpZkxRacYUlWZMkRVlrTa5dOnSmD59ekydOjXq6uriK1/5ShQKhbj11lujX79+jb5PoVCIQqHQ4Nza+nWRr2pbTjlbjHXP/vkfL5a+FO++/Fx0nPjrqBowJOrnzUquMAAAaG1s0h0xevTo6NOnTyxatCguu+yyWLJkSfziF79o0ofW1tZGdXV1g+Nnc59r0r22SO+9E8UVr0ebbr2SroQUWrnyjaivr48ePRs+etujR00sXbYioapIM2OKSjOmqDRjiqxodPN21113xTe+8Y0455xzYuTIkdG2bdNTsokTJ8bq1asbHN8Z/Kkm32+L036raNOtZ5Tq3ky6ElJo7dq1sWDBohg2dL/153K5XAwbul/MnTs/wcpIK2OKSjOmqDRjaguT4TlvjX5s8oEHHoipU6fGwIED49Of/nR87Wtfi6OPbtrqPPl8PvL5fINzazwy+ZHajxwT9U/Ni9KqFZHr3DXaH3R0RLEYaxc+kHRppNSll18V06ZeGvMXLIp58/4cp5x8fHTs2CGmXzsj6dJIKWOKSjOmqDRjiixodPM2ePDgGDx4cFx22WUxY8aMuOaaa2LChAlRLBbjnnvuiR133DE6derUnLVusXLV3WKr/zMhclt3itLbdbHuxafjnclnRrxdl3RppNSNN94WNd27xtk/OiN69aqJhQufjJGjjo3ly1du+mLYCGOKSjOmqDRjaguSQCLWUnKlUqnU1IufffbZmDp1avzXf/1XrFq1Kg466KC47bbbmnSvNd//clPLgI3a9tKHky4BACC16t9/LekSmuS9P/1Xi33WVvt/rcU+K2Iz93nr06dPXHjhhfHqq6/G7373u0rVBAAA0CSl0roWO1raZm/SHRHRtm3bOPzww5ucugEAAPDxytrnDQAAoFXL8Jy3iiRvAAAANC/JGwAAkB0lyRsAAAAJ0rwBAACkgMcmAQCA7LBgCQAAAEmSvAEAANlhwRIAAACSJHkDAACyw5w3AAAAkiR5AwAAssOcNwAAADbHa6+9Fscee2x069YtOnToEHvuuWc8+uijjb5e8gYAAGRHK53z9uabb8aQIUNi6NChcdddd0VNTU08//zz0aVLl0bfQ/MGAADQzH7605/GjjvuGNOmTVt/bpdddinrHh6bBAAAsqNYbLGjUChEXV1dg6NQKGy0rNtuuy323nvvOOqoo6JHjx7x2c9+Nq666qqyfjXNGwAAQBPU1tZGdXV1g6O2tnaj7/3LX/4SU6ZMid122y3uvvvuOPHEE+OUU06Ja6+9ttGflyuVSqVKFb851nz/y0mXQMZse+nDSZcAAJBa9e+/lnQJTfLuHZe02Ge1OWj8BklbPp+PfD6/wXvbt28fe++9dzz44IPrz51yyikxb968eOihhxr1eea8AQAANMFHNWob07t37+jXr1+Dc5/+9KfjpptuavTnad4AAIDsaKWrTQ4ZMiSeffbZBueee+652HnnnRt9D3PeAAAAmtnpp58ec+fOjZ/85CexePHiuP766+PKK6+M8ePHN/oekjcAACA7Sq0zeRs0aFDccsstMXHixDj33HNjl112icsuuyyOOeaYRt9D8wYAANACRo0aFaNGjWry9Zo3AAAgO1rpnLdKMOcNAAAgBTRvAAAAKeCxSQAAIDta6YIllSB5AwAASAHJGwAAkB0WLAEAACBJkjcAACA7JG8AAAAkSfIGAABkR6mUdAXNRvIGAACQApI3AAAgO8x5AwAAIEmSNwAAIDskbwAAACRJ8gYAAGRHSfIGAABAgiRvAABAdpjzBgAAQJIkbwAAQHaUSklX0GwkbwAAACmgeQMAAEgBj00CAADZYcESAAAAktRqkrdtL3046RLImLdu+W7SJZAhnY64KOkSAIDGkLwBAACQpFaTvAEAAGy2kuQNAACABEneAACAzCgVbdINAABAgiRvAABAdlhtEgAAgCRJ3gAAgOyw2iQAAABJkrwBAADZYbVJAAAAkiR5AwAAssNqkwAAACRJ8gYAAGSH5A0AAIAkad4AAABSwGOTAABAdpRsFQAAAECCJG8AAEB2WLAEAACAJEneAACA7Cia8wYAAECCJG8AAEB2lMx5AwAAIEGSNwAAIDvMeQMAACBJkjcAACAzSvZ5AwAAIEmSNwAAIDvMeQMAACBJkjcAACA77PMGAABAkiRvAABAdpjzBgAAQJIkbwAAQHbY5w0AAIAkad4AAABSwGOTAABAdliwBAAAgCRJ3gAAgOywSTcAAABJ0rwBAADZUSy13FGGs88+O3K5XIOjb9++Zd3DY5MAAAAtYPfdd4977713/euqqvLaMc0bAACQGaVWvEl3VVVV9OrVq8nXe2wSAACgCQqFQtTV1TU4CoXCR77/+eefj+222y4++clPxjHHHBMvv/xyWZ+neQMAALKjBee81dbWRnV1dYOjtrZ2o2Xts88+MX369Jg5c2ZMmTIl/vrXv8b+++8fb731VqN/tVypVGoVu9hVtd8+6RLImLdu+W7SJZAhnY64KOkSAKBF1b//WtIlNMma73+5xT6r3bm/2yBpy+fzkc/nN3ntqlWrYuedd45LLrkkvvGNbzTq8yRvKXLiCWNi8XNzY03dC/HgA7fHoL0/k3RJpNSyVWvirN/cG1/4wTWxz/eujH+7cEY8+crypMsi5XxHUWnGFJVmTG0hWjB5y+fz0blz5wZHYxq3iIhtt902PvWpT8XixYsb/atp3lLiqKMOjYsvmhTn/fiSGLTP8Fi46Kn4452/jZqabkmXRsrUvVOIsb+4Naratokrjh8ZN3//6Jhw2Oejc4fGfdHAxviOotKMKSrNmKK1WbNmTbzwwgvRu3fvRl/jscmUePCB22Peowvj1NN+EBERuVwuXvzLvJj8y2lx4UWTE66udfLY5MZdfsfceOyvr8e0k49IupRU8djkx/MdRaUZU1SaMVW+1D42ecZhLfZZ21z8341+7xlnnBGjR4+OnXfeOZYsWRKTJk2Kxx57LJ566qmoqalp1D0kbynQrl272Guv/jFr9p/WnyuVSjFr9gMxePDABCsjje5/8sXot2OPOOPau2Poj6bFV392Y9z00FNJl0WK+Y6i0owpKs2YojV49dVX49///d+jT58+8ZWvfCW6desWc+fObXTjFtHEfd7+9re/Rbduf4+YX3nllbjqqqvi3XffjUMPPTT233//ptySj9G9e9eoqqqK5ctWNji/fPmK6NvnXxKqirR69W91ceODT8axX+gf3/ziXvHEKyviwlseiHZVbeLQQX2TLo8U8h1FpRlTVJoxtYUptooHCzdwww03bPY9ymreHn/88Rg9enS88sorsdtuu8UNN9wQw4cPj7fffjvatGkTl156afzhD3+Iww8//GPvUygUNliVpVQqRS6XK/sXAMpTLJWi3441ccrIwRER0XeHmnjh9TfiDw8+pXkDAGjFynps8nvf+17sueeeMWfOnDjggANi1KhRMXLkyFi9enW8+eab8a1vfSsuuOCCTd5nY/shlIqN399gS7Ny5RtRX18fPXp2b3C+R4+aWLpsRUJVkVY1nbeOf+nZpcG5XXpuG6+/uSahikg731FUmjFFpRlTW5ZSsdRiR0srq3mbN29enH/++TFkyJC4+OKLY8mSJXHSSSdFmzZtok2bNnHyySfHM888s8n7TJw4MVavXt3gyLXp1ORfIuvWrl0bCxYsimFD91t/LpfLxbCh+8XcufMTrIw0GvCJXvHi8lUNzr20YnX07rpNMgWRer6jqDRjikozpsiKsh6bfOONN6JXr14REbHNNttEx44do0uXf/wX/C5dujRqh/CNbVznkcmPd+nlV8W0qZfG/AWLYt68P8cpJx8fHTt2iOnXzki6NFLm2C8MiLE/vyWuvnd+HDxg13ji5WVx09yn4odHfSHp0kgx31FUmjFFpRlTZEHZC5b8c5Ol6WoZN954W9R07xpn/+iM6NWrJhYufDJGjjo2li9fuemL4UP22KlHXDLukPj5nQ/Hlf8zP7bv2im+e9iQGDnwU0mXRor5jqLSjCkqzZjagrTSBUsqoax93tq0aRMjRoxYn5rdfvvtMWzYsOjYsWNE/H0hkpkzZ8a6devKLsQ+b1Safd6oJPu8AbClSes+b2+dMqrFPqvTz+9osc+KKDN5GzNmTIPXxx577AbvOe644zavIgAAgKYqFpOuoNmU1bxNmzatueoAAADgYzRpk24AAIBWKcNz3sraKgAAAIBkSN4AAIDskLwBAACQJMkbAACQGWXshJY6kjcAAIAUkLwBAADZYc4bAAAASZK8AQAA2SF5AwAAIEmSNwAAIDNKkjcAAACSJHkDAACyQ/IGAABAkiRvAABAdhSTLqD5SN4AAABSQPMGAACQAh6bBAAAMsNWAQAAACRK8gYAAGSH5A0AAIAkSd4AAIDssFUAAAAASZK8AQAAmWG1SQAAABIleQMAALLDnDcAAACSJHkDAAAyw5w3AAAAEiV5AwAAssOcNwAAAJIkeQMAADKjJHkDAAAgSZI3AAAgOyRvAAAAJEnzBgAAkAIemwQAADLDgiUAAAAkSvIGAABkh+QNAACAJEneAACAzDDnDQAAgERJ3gAAgMyQvAEAAJAoyRsAAJAZkjcAAAASJXkDAACyo5RLuoJmo3kjszodcVHSJZAhb93y3aRLIGN8R1Fpe3XfNekSgGameQMAADLDnDcAAAASJXkDAAAyo1TM7pw3yRsAAEAKSN4AAIDMMOcNAACAREneAACAzChleJ83yRsAAEAKaN4AAABSwGOTAABAZliwBAAAgIq54IILIpfLxWmnndboayRvAABAZqRhk+558+bFr3/96+jfv39Z10neAAAAWsiaNWvimGOOiauuuiq6dOlS1rWaNwAAIDNKpZY7CoVC1NXVNTgKhcLH1jd+/PgYOXJkHHjggWX/bpo3AACAJqitrY3q6uoGR21t7Ue+/4YbbogFCxZ87Hs+jjlvAABAZrTknLeJEyfGhAkTGpzL5/Mbfe8rr7wSp556atxzzz2x1VZbNenzNG8AAABNkM/nP7JZ+2fz58+P5cuXx1577bX+3Lp162LOnDlxxRVXRKFQiLZt237sPTRvAABAZrTW1Sa/+MUvxuOPP97g3Lhx46Jv377x/e9/f5ONW4TmDQAAoNl16tQp9thjjwbnOnbsGN26ddvg/EfRvAEAAJlRKiVdQfPRvAEAACTgvvvuK+v9mjcAACAzWuuct0qwzxsAAEAKSN4AAIDMKJUkbwAAACRI8gYAAGRGqZh0Bc1H8gYAAJACmjcAAIAU8NgkAACQGUULlgAAAJAkyRsAAJAZtgoAAAAgUZI3AAAgM0pFyRsAAAAJkrwBAACZUSolXUHzkbwBAACkgOQNAADIDHPeAAAASJTkDQAAyIyifd4AAABIkuQNAADIjJLkDQAAgCRJ3gAAgMywzxsAAACJkrwBAACZYbVJAAAAEiV5AwAAMsNqk7QKJ54wJhY/NzfW1L0QDz5wewza+zNJl0SKGU9U0rJVa+Ks39wbX/jBNbHP966Mf7twRjz5yvKkyyLlfE9RKZ/dp3/87NrauHPBTfHIkvvjC8P3S7okaBLNW0ocddShcfFFk+K8H18Sg/YZHgsXPRV/vPO3UVPTLenSSCHjiUqqe6cQY39xa1S1bRNXHD8ybv7+0THhsM9H5w75pEsjxXxPUUlbbd0hnn9ycVx01mVJlwKbRfOWEqefenxcPfX6uPa638fTTz8fJ40/M955590YN/bopEsjhYwnKmna7D9Hr207xrn/Piz23LlnbN+tc3y+z46xY/fqpEsjxXxPUUkP/e/D8asLp8Z9M/+UdCm0gFKp5Y6WVlbzNnv27OjXr1/U1dVt8LPVq1fH7rvvHn/6k/9TVFq7du1ir736x6zZ//hnWyqVYtbsB2Lw4IEJVkYaGU9U2v1Pvhj9duwRZ1x7dwz90bT46s9ujJseeirpskgx31MAG1dW83bZZZfF8ccfH507d97gZ9XV1fGtb30rLrnkkooVx9917941qqqqYvmylQ3OL1++Inr1rEmoKtLKeKLSXv1bXdz44JOxU/fqmPIfo+Koz+8eF97yQNw275mkSyOlfE8Bm6NYyrXY0dLKat4WLlwYw4cP/8ifH3zwwTF//vxN3qdQKERdXV2Do5TlrdABMqxYKkXfHbrHKSMHR98dauLf9u0XXx7cL/7woPQNACqprOZt2bJl0a5du4/8eVVVVaxYsWKT96mtrY3q6uoGR6n4VjmlbFFWrnwj6uvro0fP7g3O9+hRE0uXbfqfN3yY8USl1XTeOv6lZ5cG53bpuW28/uaahCoi7XxPAZujVMq12NHSymrett9++3jiiSc+8ueLFi2K3r17b/I+EydOjNWrVzc4cm06lVPKFmXt2rWxYMGiGDb0H8va5nK5GDZ0v5g7d9NJJ3yY8USlDfhEr3hx+aoG515asTp6d90mmYJIPd9TABtX1ibdX/rSl+KHP/xhDB8+PLbaaqsGP3v33Xdj0qRJMWrUqE3eJ5/PRz7fcAnpXC67m+lVwqWXXxXTpl4a8xcsinnz/hynnHx8dOzYIaZfOyPp0kgh44lKOvYLA2Lsz2+Jq++dHwcP2DWeeHlZ3DT3qfjhUV9IujRSzPcUldRh6w6xwy7br3+93Y69Y7fdd426VXWx7DV7UmZNEnPRWkpZzdsPfvCDuPnmm+NTn/pUfPvb344+ffpERMQzzzwTkydPjnXr1sV//ud/NkuhW7obb7wtarp3jbN/dEb06lUTCxc+GSNHHRvLl6/c9MXwT4wnKmmPnXrEJeMOiZ/f+XBc+T/zY/uuneK7hw2JkQM/lXRppJjvKSrp0wP6xK9uunz969PP+XZERNwx46449/QLkioLypYrlblSyEsvvRQnnnhi3H333esXGcnlcnHIIYfE5MmTY5dddmlSIVXtt9/0mwAS8tYt3026BDKm0xEXJV0CGbNX912TLoGMeWTJ/UmX0CRzt/tyi33W4CU3t9hnRZSZvEVE7LzzzvHHP/4x3nzzzVi8eHGUSqXYbbfdokuXLpu+GAAAgCYpu3n7QJcuXWLQoEGVrAUAAGCzZHnOW1mrTQIAAJCMJidvAAAArU0S+6+1FMkbAABACkjeAACAzCgmXUAzkrwBAACkgOQNAADIjFKY8wYAAECCNG8AAAAp4LFJAAAgM4qlpCtoPpI3AACAFJC8AQAAmVG0YAkAAABJkrwBAACZYasAAAAAEiV5AwAAMqOYdAHNSPIGAACQApI3AAAgM8x5AwAAIFGSNwAAIDPMeQMAACBRkjcAACAzJG8AAAAkSvIGAABkhtUmAQAASJTkDQAAyIxidoM3yRsAAEAaSN4AAIDMKJrzBgAAQFNNmTIl+vfvH507d47OnTvHvvvuG3fddVdZ99C8AQAANLMddtghLrjggpg/f348+uijMWzYsDjssMPiySefbPQ9PDYJAABkRinpAj7C6NGjG7w+//zzY8qUKTF37tzYfffdG3UPzRsAAEATFAqFKBQKDc7l8/nI5/Mfe926devixhtvjLfffjv23XffRn+exyYBAIDMKLbgUVtbG9XV1Q2O2traj6zt8ccfj2222Sby+XyccMIJccstt0S/fv0a/btJ3gAAAJpg4sSJMWHChAbnPi5169OnTzz22GOxevXq+MMf/hBjxoyJ+++/v9ENnOYNAADIjGKu5bYKaMwjkh/Wvn372HXXXSMiYuDAgTFv3ry4/PLL49e//nWjrvfYJAAAQAKKxeIGc+Y+juQNAADIjNa62uTEiRNjxIgRsdNOO8Vbb70V119/fdx3331x9913N/oemjcAAIBmtnz58jjuuOPi9ddfj+rq6ujfv3/cfffdcdBBBzX6Hpo3AAAgM4pJF/ARpk6dutn3MOcNAAAgBSRvAABAZhRbbrHJFid5AwAASAHJGwAAkBnFyG70JnkDAABIAckbAACQGa11n7dKkLwBAACkgOQNAADIDKtNAgAAkKhWk7ydvN3+SZdAxvxiyZ+SLoEM6XTERUmXQMa86zuKChvSf1zSJQDNrNU0bwAAAJurmHQBzchjkwAAACkgeQMAADLDVgEAAAAkSvIGAABkhq0CAAAASJTkDQAAyAyrTQIAAJAoyRsAAJAZkjcAAAASJXkDAAAyo2S1SQAAAJIkeQMAADLDnDcAAAASJXkDAAAyQ/IGAABAoiRvAABAZpSSLqAZSd4AAABSQPIGAABkRtE+bwAAACRJ8wYAAJACHpsEAAAyw1YBAAAAJEryBgAAZIbkDQAAgERJ3gAAgMywSTcAAACJkrwBAACZYZNuAAAAEiV5AwAAMsNqkwAAACRK8gYAAGSG1SYBAABIlOQNAADIjGKGszfJGwAAQApI3gAAgMyw2iQAAACJkrwBAACZkd0Zb5I3AACAVNC8AQAApIDHJgEAgMywYAkAAACJkrwBAACZUcwlXUHzkbwBAACkgOQNAADIjGKGNwuQvAEAAKSA5A0AAMiM7OZukrdU+OJJh8Xp/31+1D4xLc599Nfx9Su/EzWf7J10WaTciSeMicXPzY01dS/Egw/cHoP2/kzSJZFyxhSVcvCRY2KPISM2OH78s8lJl0ZKfXaf/vGza2vjzgU3xSNL7o8vDN8v6ZKgSTRvKfAv+3w6Hviv/4nLj/hh/Opr50fbqrZxwnVnRfsO+aRLI6WOOurQuPiiSXHejy+JQfsMj4WLnoo/3vnbqKnplnRppJQxRSXdcPXlcd9tv11/XHXZTyIi4uCh+ydcGWm11dYd4vknF8dFZ12WdCm0gGILHi1N85YCV465IOb94f5Y+vyrseTpl+P6M6ZE1x1qYoc9d0m6NFLq9FOPj6unXh/XXvf7ePrp5+Ok8WfGO++8G+PGHp10aaSUMUUlde2ybXTv1nX9cf//fTh23L53DPrsnkmXRko99L8Px68unBr3zfxT0qXAZim7eSsWi3HNNdfEqFGjYo899og999wzDj300LjuuuuiVMryE6atR4dOW0dExDur1iRcCWnUrl272Guv/jFr9j/+gJVKpZg1+4EYPHhggpWRVsYUzWnt2rVxx//8bxwx8uDI5TK8eRNQMcUotdjR0spq3kqlUhx66KHxzW9+M1577bXYc889Y/fdd4+XXnopxo4dG0cccURz1cn/l8vl4vAfjYm/zHsmlj73atLlkELdu3eNqqqqWL5sZYPzy5eviF49axKqijQzpmhOs+Y8FG+tWROHf+mgpEsBSFxZq01Onz495syZE7NmzYqhQ4c2+Nns2bPj8MMPj+uuuy6OO+64j71PoVCIQqHQ4Fx9aV1U5dqWU84W6cjzvh69++wYP/+3SUmXAgDN7uY77o79Bu8dPcyfBBopy88ClpW8/e53v4uzzjprg8YtImLYsGFx5plnxm9/+9tN3qe2tjaqq6sbHPNWP11OKVukL58zLvoN2ysmH31urF76RtLlkFIrV74R9fX10aNn9wbne/SoiaXLViRUFWlmTNFclixdFnMffSyOHD086VIAWoWymrdFixbF8OEf/QU6YsSIWLhw4SbvM3HixFi9enWDY1D1p8spZYvz5XPGxZ6HDIpf/p/z4o1X/csQTbd27dpYsGBRDBv6j2WSc7lcDBu6X8ydOz/BykgrY4rmcsud90TXLtXxr/t+LulSgBTJ8mqTZT02+cYbb0TPnj0/8uc9e/aMN998c5P3yefzkc83XObeI5Mf7cjzvh4DDxsSU4+/OApvvxudaqojIuK9undibWFtwtWRRpdeflVMm3ppzF+wKObN+3OccvLx0bFjh5h+7YykSyOljCkqrVgsxq133hOHjTgwqqr8OwKbp8PWHWKHXbZf/3q7HXvHbrvvGnWr6mLZa8sTrAzKU1bztm7duqiq+uhL2rZtG/X19ZtdFA3t97WDIyLi2zMaznO7/owpMe8P9ydREil34423RU33rnH2j86IXr1qYuHCJ2PkqGNj+fKVm74YNsKYotIemvfneH3Z8jhi5MFJl0IGfHpAn/jVTZevf336Od+OiIg7ZtwV555+QVJl0UySWAWyMWpra+Pmm2+OZ555Jjp06BCf//zn46c//Wn06dOn0ffIlcpY379NmzYxYsSIDVKzDxQKhZg5c2asW7eu0QV84PRP2AuIyvrFEnu5AK3Xu76jqLAh/cclXQIZ88iSdIYEE1qwr7jkxRsa/d7hw4fH0UcfHYMGDYr6+vo466yz4oknnoinnnoqOnbs2Kh7lJW8jRkzZpPv2dRKkwAAAM2ldeZuETNnzmzwevr06dGjR4+YP39+/Ou//muj7lFW8zZt2rRy3g4AAJBZG9sCbWPre2zM6tWrIyKia9eujf68slabBAAA4O82tgVabW3tJq8rFotx2mmnxZAhQ2KPPfZo9OeVlbwBAAC0Zi25hP/EiRNjwoQJDc41JnUbP358PPHEE/HAAw+U9XmaNwAAgCZo7COSH/btb3877rjjjpgzZ07ssMMOZV2reQMAADKj1EqXLCmVSnHyySfHLbfcEvfdd1/ssssuZd9D8wYAANDMxo8fH9dff33893//d3Tq1CmWLl0aERHV1dXRoUOHRt1D8wYAAGRGS855K8eUKVMiIuKAAw5ocH7atGkxduzYRt1D8wYAANDMSqXNf5xT8wYAAGRGsZXOeasE+7wBAACkgOQNAADIjOzmbpI3AACAVJC8AQAAmWHOGwAAAImSvAEAAJnRWvd5qwTJGwAAQApI3gAAgMwomfMGAABAkiRvAABAZpjzBgAAQKI0bwAAACngsUkAACAzLFgCAABAoiRvAABAZliwBAAAgERJ3gAAgMwolsx5AwAAIEGSNwAAIDOym7tJ3gAAAFJB8gYAAGRGMcPZm+QNAAAgBSRvAABAZpQkbwAAACRJ8gYAAGRGMekCmpHkDQAAIAUkbwAAQGZYbRIAAIBESd4AAIDMsNokAAAAiZK8AQAAmWG1SQAAABKleQMAAEgBj00CAACZUSpZsAQAAIAESd4AAIDMsEk3AAAAiZK8AQAAmZHlrQJaTfP2iyV/SroEAGgxHbbbP+kSyJi/9O+bdAlAM2s1zRsAAMDmKpnzBgAAQJIkbwAAQGZYbRIAAIBESd4AAIDMKJUkbwAAACRI8gYAAGRGlvd5k7wBAACkgOQNAADIDPu8AQAAkCjJGwAAkBn2eQMAACBRmjcAAIAU8NgkAACQGTbpBgAAIFGSNwAAIDMsWAIAAECiJG8AAEBm2KQbAACAREneAACAzChabRIAAIAkSd4AAIDMyG7uJnkDAABIBckbAACQGfZ5AwAAIFGSNwAAIDMkbwAAADTZnDlzYvTo0bHddttFLpeLW2+9tex7aN4AAIDMKJVKLXaU4+23344BAwbE5MmTm/y7eWwSAACgmY0YMSJGjBixWffQvAEAAJnRknPeCoVCFAqFBufy+Xzk8/lm+TyPTQIAADRBbW1tVFdXNzhqa2ub7fMkbwAAQGaUWjB5mzhxYkyYMKHBueZK3SI0bwAAAE3SnI9IbozHJgEAAFJA8gYAAGRGuUv4t5Q1a9bE4sWL17/+61//Go899lh07do1dtppp0bdQ/MGAADQzB599NEYOnTo+tcfzJUbM2ZMTJ8+vVH30LwBAACZ0ZJbBZTjgAMO2OxU0Jw3AACAFJC8AQAAmdFa57xVguQNAAAgBSRvAABAZrTWOW+VIHkDAABIAckbAACQGSXJGwAAAEmSvAEAAJlRtNokAAAASZK8AQAAmWHOGwAAAInSvKXIiSeMicXPzY01dS/Egw/cHoP2/kzSJZFixhOVZkxRacYUzaXzmKNjp0dnxbYTTkq6FJpBsVRqsaOlad5S4qijDo2LL5oU5/34khi0z/BYuOip+OOdv42amm5Jl0YKGU9UmjFFpRlTNJf2/frENl8eFe8/90LSpUDZNG8pcfqpx8fVU6+Pa6/7fTz99PNx0vgz45133o1xY49OujRSyHii0owpKs2YojnkOmwV3c47K/52/iVRfOutpMuhmZRa8H8tTfOWAu3atYu99uofs2b/af25UqkUs2Y/EIMHD0ywMtLIeKLSjCkqzZiiuXT5/qnx7v+dG4VHFiRdCjRJWc3bl770pVi9evX61xdccEGsWrVq/eu//e1v0a9fv03ep1AoRF1dXYOjlOH9GDZX9+5do6qqKpYvW9ng/PLlK6JXz5qEqiKtjCcqzZii0owpmsPWBw+N9n13jVVXXJ10KdBkZTVvd999dxQKhfWvf/KTn8Qbb7yx/nV9fX08++yzm7xPbW1tVFdXNzhKRdE1AACV17ZnTXT5zvj42w9qI95fm3Q5NLMsL1hS1j5v/5yONTUtmzhxYkyYMKHBuS7d+jbpXluClSvfiPr6+ujRs3uD8z161MTSZSsSqoq0Mp6oNGOKSjOmqLT2fT8Vbbt1iV6/+dX6c7mqtpH/bP/o9JXD45XPD48oFhOsEBonkTlv+Xw+Onfu3ODI5XJJlJIKa9eujQULFsWwofutP5fL5WLY0P1i7tz5CVZGGhlPVJoxRaUZU1Tae/MWxOtf/UYsPeY/1h+FJ5+Jd2bOiqXH/IfGLWOyvGBJWclbLpfboMnSdLWMSy+/KqZNvTTmL1gU8+b9OU45+fjo2LFDTL92RtKlkULGE5VmTFFpxhSVVHrn3Vj7wosNz733XqxbVbfBeWjNyn5scuzYsZHP5yMi4r333osTTjghOnbsGBHRYD4clXXjjbdFTfeucfaPzohevWpi4cInY+SoY2P58pWbvhj+ifFEpRlTVJoxBTRVEnPRWkquVMbEtXHjxjXqfdOmTSu7kKr225d9DQAAf/eX/tYPoLJ2enRW0iU0yb9036vFPuuFlS277URZyVtTmjIAAICWksRctJZik24AAIAUKCt5AwAAaM1KpeyuHip5AwAASAHJGwAAkBlFc94AAABIkuQNAADIjDJ2QksdyRsAAEAKSN4AAIDMMOcNAACAREneAACAzDDnDQAAgERJ3gAAgMwoSt4AAABIkuYNAAAgBTw2CQAAZEbJVgEAAAAkSfIGAABkhq0CAAAASJTkDQAAyIyiOW8AAAAkSfIGAABkhjlvAAAAJEryBgAAZEZR8gYAAECSJG8AAEBmmPMGAABAoiRvAABAZtjnDQAAgERJ3gAAgMww5w0AAIBESd4AAIDMsM8bAAAAidK8AQAApIDHJgEAgMwo2SoAAACAJEneAACAzLBgCQAAAImSvAEAAJlhk24AAAASJXkDAAAyw2qTAAAAJEryBgAAZIY5bwAAACRK8wYAAGRGqVRqsaMpJk+eHJ/4xCdiq622in322SceeeSRRl+reQMAAGgBM2bMiAkTJsSkSZNiwYIFMWDAgDjkkENi+fLljbpe8wYAAGRGqQWPcl1yySVx/PHHx7hx46Jfv37xq1/9Krbeeuu45pprGnW95g0AAKAJCoVC1NXVNTgKhcJG3/v+++/H/Pnz48ADD1x/rk2bNnHggQfGQw891KjPazWrTda//1rSJbR6hUIhamtrY+LEiZHP55Muhwwwpqg0Y4pKM6aoJONpy9CSfcXZZ58d55xzToNzkyZNirPPPnuD965cuTLWrVsXPXv2bHC+Z8+e8cwzzzTq83KlLK+lmTF1dXVRXV0dq1evjs6dOyddDhlgTFFpxhSVZkxRScYTlVYoFDZI2vL5/Eb/48CSJUti++23jwcffDD23Xff9ee/973vxf333x8PP/zwJj+v1SRvAAAAafJRjdrGdO/ePdq2bRvLli1rcH7ZsmXRq1evRt3DnDcAAIBm1r59+xg4cGDMmjVr/blisRizZs1qkMR9HMkbAABAC5gwYUKMGTMm9t577/jc5z4Xl112Wbz99tsxbty4Rl2veUuRfD4fkyZNMsGWijGmqDRjikozpqgk44mkffWrX40VK1bEj370o1i6dGl85jOfiZkzZ26wiMlHsWAJAABACpjzBgAAkAKaNwAAgBTQvAEAAKSA5g0AACAFNG8p8dBDD0Xbtm1j5MiRSZdCyo0dOzZyudz6o1u3bjF8+PBYtGhR0qWRckuXLo2TTz45PvnJT0Y+n48dd9wxRo8e3WA/G2iMD39PtWvXLnr27BkHHXRQXHPNNVEsFpMujxT65799HxzDhw9PujQoi+YtJaZOnRonn3xyzJkzJ5YsWZJ0OaTc8OHD4/XXX4/XX389Zs2aFVVVVTFq1KikyyLFXnzxxRg4cGDMnj07Lrroonj88cdj5syZMXTo0Bg/fnzS5ZFCH3xPvfjii3HXXXfF0KFD49RTT41Ro0ZFfX190uWRQh/+2/fB8bvf/S7psqAs9nlLgTVr1sSMGTPi0UcfjaVLl8b06dPjrLPOSrosUiyfz0evXr0iIqJXr15x5plnxv777x8rVqyImpqahKsjjU466aTI5XLxyCOPRMeOHdef33333ePrX/96gpWRVh/+ntp+++1jr732isGDB8cXv/jFmD59enzzm99MuELS5sNjCtJK8pYCv//976Nv377Rp0+fOPbYY+Oaa64J2/NRKWvWrInf/OY3seuuu0a3bt2SLocUeuONN2LmzJkxfvz4Bo3bB7bddtuWL4pMGjZsWAwYMCBuvvnmpEsBSITmLQWmTp0axx57bET8PfJfvXp13H///QlXRZrdcccdsc0228Q222wTnTp1ittuuy1mzJgRbdr4SqB8ixcvjlKpFH379k26FLYAffv2jRdffDHpMkihD//t++D4yU9+knRZUBaPTbZyzz77bDzyyCNxyy23REREVVVVfPWrX42pU6fGAQcckGxxpNbQoUNjypQpERHx5ptvxi9/+csYMWJEPPLII7HzzjsnXB1p40kAWlKpVIpcLpd0GaTQh//2faBr164JVQNNo3lr5aZOnRr19fWx3XbbrT9XKpUin8/HFVdcEdXV1QlWR1p17Ngxdt111/Wvr7766qiuro6rrroqfvzjHydYGWm02267RS6Xi2eeeSbpUtgCPP3007HLLrskXQYp9M9/+yCNPCPVitXX18d1110XP/vZz+Kxxx5bfyxcuDC22247KyRRMblcLtq0aRPvvvtu0qWQQl27do1DDjkkJk+eHG+//fYGP1+1alXLF0UmzZ49Ox5//PE48sgjky4FIBGSt1bsjjvuiDfffDO+8Y1vbJCwHXnkkTF16tQ44YQTEqqONCsUCrF06dKI+Ptjk1dccUWsWbMmRo8enXBlpNXkyZNjyJAh8bnPfS7OPffc6N+/f9TX18c999wTU6ZMiaeffjrpEkmZD76n1q1bF8uWLYuZM2dGbW1tjBo1Ko477rikyyOFPvy37wNVVVXRvXv3hCqC8mneWrGpU6fGgQceuNFHI4888si48MILY9GiRdG/f/8EqiPNZs6cGb17946IiE6dOkXfvn3jxhtvNI+SJvvkJz8ZCxYsiPPPPz++853vxOuvvx41NTUxcODADeaYQGN88D1VVVUVXbp0iQEDBsTPf/7zGDNmjMWVaJIP/+37QJ8+fTzyTarkSmaaAwAAtHr+0xUAAEAKaN4AAABSQPMGAACQApo3AACAFNC8AQAApIDmDQAAIAU0bwAAACmgeQMAAEgBzRsAAEAKaN4AAABSQPMGAACQApo3AACAFPh/JjvFdS+oC5gAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 1200x700 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_pred = []\n",
"y_true = []\n",
"\n",
"for inputs, label in test_set:\n",
" output = spoter(inputs) # Feed Network\n",
"\n",
" output = torch.argmax(output, dim=2) # Get Prediction\n",
" output = output.detach().numpy().tolist()[0] # Convert to list\n",
" \n",
" y_true.extend(output) # Save Truth\n",
"\n",
" y_pred.append(label) # Save Prediction\n",
"\n",
"# constant for classes\n",
"classes = ('A', 'B', 'C', 'D', 'E')\n",
"\n",
"# Build confusion matrix\n",
"cf_matrix = confusion_matrix(y_true, y_pred)\n",
"df_cm = pd.DataFrame(cf_matrix, index = [i for i in classes],\n",
" columns = [i for i in classes])\n",
"plt.figure(figsize = (12,7))\n",
"sn.heatmap(df_cm, annot=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}