diff --git a/.drone.yml b/.drone.yml index c41e13e..30ff4a1 100644 --- a/.drone.yml +++ b/.drone.yml @@ -7,7 +7,7 @@ steps: pull: if-not-exists image: sonarsource/sonar-scanner-cli commands: - - sonar-scanner -Dsonar.host.url=$SONAR_HOST -Dsonar.login=$SONAR_TOKEN -Dsonar.projectKey=$SONAR_PROJECT_KEY + - sonar-scanner -Dsonar.host.url=$SONAR_HOST -Dsonar.login=$SONAR_TOKEN -Dsonar.projectKey=$SONAR_PROJECT_KEY -Dsonar.qualitygate.wait=true environment: SONAR_HOST: from_secret: sonar_host diff --git a/.gitignore b/.gitignore index 6147721..8f8a4ba 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ cache_wlasl/ __pycache__/ -checkpoints/ \ No newline at end of file +checkpoints/ +.ipynb_checkpoints \ No newline at end of file diff --git a/analyze_model.ipynb b/analyze_model.ipynb new file mode 100644 index 0000000..f02de7a --- /dev/null +++ b/analyze_model.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from src.keypoint_extractor import KeypointExtractor\n", + "from src.datasets.finger_spelling_dataset import FingerSpellingDataset\n", + "from src.identifiers import LANDMARKS\n", + "import torch\n", + "from src.model import SPOTER\n", + "from sklearn.metrics import confusion_matrix\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sn\n", + "import matplotlib.pyplot as plt\n", + "from src.augmentations import MirrorKeypoints\n", + "\n", + "keypoints_extractor = KeypointExtractor(\"data/fingerspelling/data/\")\n", + "test_set = FingerSpellingDataset(\"data/fingerspelling/data/\", keypoints_extractor, keypoints_identifier=LANDMARKS, subset=\"val\", transform=MirrorKeypoints())\n", + "\n", + "spoter = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)\n", + 
"spoter.load_state_dict(torch.load('models/spoter_40.pth'))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA28AAAJMCAYAAABtgJ7QAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAydUlEQVR4nO3de5xVZb0/8O+GgS0ijFyGi9fsaBAqlEhi6EnICwR4yWN5fppA5Uklb2Qlngov2Zial5IoFUFPmWRejpfEo3CU/CmKkOBdsbwitxRGvGwZ9v790U9yAmX2sGfWrMX73Wv9sdfstfZ3fD3t8etnPc+TK5VKpQAAAKBVa5N0AQAAAGya5g0AACAFNG8AAAApoHkDAABIAc0bAABACmjeAAAAUkDzBgAAkAKaNwAAgBTQvAEAAKSA5g0AACAFNG8AAADN7BOf+ETkcrkNjvHjxzf6HlXNWB8AAAARMW/evFi3bt3610888UQcdNBBcdRRRzX6HrlSqVRqjuIAAADYuNNOOy3uuOOOeP755yOXyzXqGskbAABAExQKhSgUCg3O5fP5yOfzH3vd+++/H7/5zW9iwoQJjW7cIlpR87bm+19OugQyZttLH066BACA1Kp//7WkS2iStSv/0mKfVXvFdXHOOec0ODdp0qQ4++yzP/a6W2+9NVatWhVjx44t6/NazWOTmjcqTfMGANB0mrdNK3bavknJ2yGHHBLt27eP22+/vazPazXJGwAAwGYrrtv0eyqkMY3aP3vppZfi3nvvjZtvvrnsz7NVAAAAQAuZNm1a9OjRI0aOHFn2tZI3AAAgO0rFpCv4SMViMaZNmxZjxoyJqqryWzHJGwAAQAu499574+WXX46vf/3rTbpe8gYAANACDj744Nic9SI1bwAAQHYUW+9jk5vLY5MAAAApIHkDAAAyo9SKFyzZXJI3AACAFJC8AQAA2WHOGwAAAEmSvAEAANlhzhsAAABJkrwBAADZUVyXdAXNRvIGAACQApI3AAAgO8x5AwAAIEmSNwAAIDvs8wYAAECSJG8AAEBmlMx5AwAAIEmSNwAAIDvMeQMAACBJmjcAAIAU8NgkAACQHRYsAQAAIEmSNwAAIDuK65KuoNlI3gAAAFJA8gYAAGSHOW8AAAAkSfIGAABkh026AQAASJLkDQAAyA5z3gAAAEiS5A0AAMgOc94AAABIkuQNAADIjFJpXdIlNBvJGwAAQApI3gAAgOyw2iQAAABJkrwBAADZYbVJAAAAkiR5AwAAssOcNwAAAJKkeQMAAEgBj00CAADZUbRJNwlqf+BXY5uf3tzg2Po7P0+6LFLuxBPGxOLn5saauhfiwQduj0F7fybpkkg5Y4pKM6aoNGOKtNO8pcS6pS/H2+d9ff3xzpT/TLokUuyoow6Niy+aFOf9+JIYtM/wWLjoqfjjnb+NmppuSZdGShlTVJoxRaUZU1uQUrHljhZW0ebtiSeeqOTt+LDiuiitWbX+iHfeSroiUuz0U4+Pq6deH9de9/t4+unn46TxZ8Y777wb48YenXRppJQxRaUZU1SaMUUWbHbz9tZbb8WVV14Zn/vc52LAgAGVqImNaNO9d2z9n1fH1t/7ZeSPPi1y23ZPuiRSql27drHXXv1j1uw/rT9XKpVi1uwHYvDggQlWRloZU1SaMUWlGVNbmGKx5Y4W1
uTmbc6cOTFmzJjo3bt3XHzxxTFs2LCYO3duJWvj/1v3ynPx3u9/Ee9NPS8Kt14Zbbr2iA4nnB/RfqukSyOFunfvGlVVVbF82coG55cvXxG9etYkVBVpZkxRacYUlWZMkRVlrTa5dOnSmD59ekydOjXq6uriK1/5ShQKhbj11lujX79+jb5PoVCIQqHQ4Nza+nWRr2pbTjlbjHXP/vkfL5a+FO++/Fx0nPjrqBowJOrnzUquMAAAaG1s0h0xevTo6NOnTyxatCguu+yyWLJkSfziF79o0ofW1tZGdXV1g+Nnc59r0r22SO+9E8UVr0ebbr2SroQUWrnyjaivr48ePRs+etujR00sXbYioapIM2OKSjOmqDRjiqxodPN21113xTe+8Y0455xzYuTIkdG2bdNTsokTJ8bq1asbHN8Z/Kkm32+L036raNOtZ5Tq3ky6ElJo7dq1sWDBohg2dL/153K5XAwbul/MnTs/wcpIK2OKSjOmqDRjaguT4TlvjX5s8oEHHoipU6fGwIED49Of/nR87Wtfi6OPbtrqPPl8PvL5fINzazwy+ZHajxwT9U/Ni9KqFZHr3DXaH3R0RLEYaxc+kHRppNSll18V06ZeGvMXLIp58/4cp5x8fHTs2CGmXzsj6dJIKWOKSjOmqDRjiixodPM2ePDgGDx4cFx22WUxY8aMuOaaa2LChAlRLBbjnnvuiR133DE6derUnLVusXLV3WKr/zMhclt3itLbdbHuxafjnclnRrxdl3RppNSNN94WNd27xtk/OiN69aqJhQufjJGjjo3ly1du+mLYCGOKSjOmqDRjaguSQCLWUnKlUqnU1IufffbZmDp1avzXf/1XrFq1Kg466KC47bbbmnSvNd//clPLgI3a9tKHky4BACC16t9/LekSmuS9P/1Xi33WVvt/rcU+K2Iz93nr06dPXHjhhfHqq6/G7373u0rVBAAA0CSl0roWO1raZm/SHRHRtm3bOPzww5ucugEAAPDxytrnDQAAoFXL8Jy3iiRvAAAANC/JGwAAkB0lyRsAAAAJ0rwBAACkgMcmAQCA7LBgCQAAAEmSvAEAANlhwRIAAACSJHkDAACyw5w3AAAAkiR5AwAAssOcNwAAADbHa6+9Fscee2x069YtOnToEHvuuWc8+uijjb5e8gYAAGRHK53z9uabb8aQIUNi6NChcdddd0VNTU08//zz0aVLl0bfQ/MGAADQzH7605/GjjvuGNOmTVt/bpdddinrHh6bBAAAsqNYbLGjUChEXV1dg6NQKGy0rNtuuy323nvvOOqoo6JHjx7x2c9+Nq666qqyfjXNGwAAQBPU1tZGdXV1g6O2tnaj7/3LX/4SU6ZMid122y3uvvvuOPHEE+OUU06Ja6+9ttGflyuVSqVKFb851nz/y0mXQMZse+nDSZcAAJBa9e+/lnQJTfLuHZe02Ge1OWj8BklbPp+PfD6/wXvbt28fe++9dzz44IPrz51yyikxb968eOihhxr1eea8AQAANMFHNWob07t37+jXr1+Dc5/+9KfjpptuavTnad4AAIDsaKWrTQ4ZMiSeffbZBueee+652HnnnRt9D3PeAAAAmtnpp58ec+fOjZ/85CexePHiuP766+PKK6+M8ePHN/oekjcAACA7Sq0zeRs0aFDccsstMXHixDj33HNjl112icsuuyyOOeaYRt9D8wYAANACRo0aFaNGjWry9Zo3AAAgO1rpnLdKMOcNAAAgBTRvAAAAKeCxSQAAIDta6YIllSB5AwAASAHJGwAAkB0WLAEAACBJkjcAACA7JG8AAAAkSfIGAABkR6mUdAXNRvIGAACQApI3AAAgO8x5AwAAIEmSNwAAIDskbwAAACRJ8gYAAGRHSfIGAABAgiRvAABAdpjzBgAAQJIkbwAAQHaUSklX0GwkbwAAACmgeQMAAEgBj00CAADZYcESAAAAktRqkrdtL3046RLImLdu+W7SJZAhnY64KOkSAIDGkLwBAACQp
FaTvAEAAGy2kuQNAACABEneAACAzCgVbdINAABAgiRvAABAdlhtEgAAgCRJ3gAAgOyw2iQAAABJkrwBAADZYbVJAAAAkiR5AwAAssNqkwAAACRJ8gYAAGSH5A0AAIAkad4AAABSwGOTAABAdpRsFQAAAECCJG8AAEB2WLAEAACAJEneAACA7Cia8wYAAECCJG8AAEB2lMx5AwAAIEGSNwAAIDvMeQMAACBJkjcAACAzSvZ5AwAAIEmSNwAAIDvMeQMAACBJkjcAACA77PMGAABAkiRvAABAdpjzBgAAQJIkbwAAQHbY5w0AAIAkad4AAABSwGOTAABAdliwBAAAgCRJ3gAAgOywSTcAAABJ0rwBAADZUSy13FGGs88+O3K5XIOjb9++Zd3DY5MAAAAtYPfdd4977713/euqqvLaMc0bAACQGaVWvEl3VVVV9OrVq8nXe2wSAACgCQqFQtTV1TU4CoXCR77/+eefj+222y4++clPxjHHHBMvv/xyWZ+neQMAALKjBee81dbWRnV1dYOjtrZ2o2Xts88+MX369Jg5c2ZMmTIl/vrXv8b+++8fb731VqN/tVypVGoVu9hVtd8+6RLImLdu+W7SJZAhnY64KOkSAKBF1b//WtIlNMma73+5xT6r3bm/2yBpy+fzkc/nN3ntqlWrYuedd45LLrkkvvGNbzTq8yRvKXLiCWNi8XNzY03dC/HgA7fHoL0/k3RJpNSyVWvirN/cG1/4wTWxz/eujH+7cEY8+crypMsi5XxHUWnGFJVmTG0hWjB5y+fz0blz5wZHYxq3iIhtt902PvWpT8XixYsb/atp3lLiqKMOjYsvmhTn/fiSGLTP8Fi46Kn4452/jZqabkmXRsrUvVOIsb+4Naratokrjh8ZN3//6Jhw2Oejc4fGfdHAxviOotKMKSrNmKK1WbNmTbzwwgvRu3fvRl/jscmUePCB22Peowvj1NN+EBERuVwuXvzLvJj8y2lx4UWTE66udfLY5MZdfsfceOyvr8e0k49IupRU8djkx/MdRaUZU1SaMVW+1D42ecZhLfZZ21z8341+7xlnnBGjR4+OnXfeOZYsWRKTJk2Kxx57LJ566qmoqalp1D0kbynQrl272Guv/jFr9p/WnyuVSjFr9gMxePDABCsjje5/8sXot2OPOOPau2Poj6bFV392Y9z00FNJl0WK+Y6i0owpKs2YojV49dVX49///d+jT58+8ZWvfCW6desWc+fObXTjFtHEfd7+9re/Rbduf4+YX3nllbjqqqvi3XffjUMPPTT233//ptySj9G9e9eoqqqK5ctWNji/fPmK6NvnXxKqirR69W91ceODT8axX+gf3/ziXvHEKyviwlseiHZVbeLQQX2TLo8U8h1FpRlTVJoxtYUptooHCzdwww03bPY9ymreHn/88Rg9enS88sorsdtuu8UNN9wQw4cPj7fffjvatGkTl156afzhD3+Iww8//GPvUygUNliVpVQqRS6XK/sXAMpTLJWi3441ccrIwRER0XeHmnjh9TfiDw8+pXkDAGjFynps8nvf+17sueeeMWfOnDjggANi1KhRMXLkyFi9enW8+eab8a1vfSsuuOCCTd5nY/shlIqN399gS7Ny5RtRX18fPXp2b3C+R4+aWLpsRUJVkVY1nbeOf+nZpcG5XXpuG6+/uSahikg731FUmjFFpRlTW5ZSsdRiR0srq3mbN29enH/++TFkyJC4+OKLY8mSJXHSSSdFmzZtok2bNnHyySfHM888s8n7TJw4MVavXt3gyLXp1ORfIuvWrl0bCxYsimFD91t/LpfLxbCh+8XcufMTrIw0GvCJXvHi8lUNzr20YnX07rpNMgWRer6jqDRjikozpsiKsh6bfOONN6JXr14REbHNNttEx44do0uXf/wX/C5dujRqh/CNbVznkcmPd+nlV8W0qZfG/AWLYt68P8cpJx8fHTt2iOnXzki6NFLm2C8MiLE/vyWuvnd+HDxg13ji5
WVx09yn4odHfSHp0kgx31FUmjFFpRlTZEHZC5b8c5Ol6WoZN954W9R07xpn/+iM6NWrJhYufDJGjjo2li9fuemL4UP22KlHXDLukPj5nQ/Hlf8zP7bv2im+e9iQGDnwU0mXRor5jqLSjCkqzZjagrTSBUsqoax93tq0aRMjRoxYn5rdfvvtMWzYsOjYsWNE/H0hkpkzZ8a6devKLsQ+b1Safd6oJPu8AbClSes+b2+dMqrFPqvTz+9osc+KKDN5GzNmTIPXxx577AbvOe644zavIgAAgKYqFpOuoNmU1bxNmzatueoAAADgYzRpk24AAIBWKcNz3sraKgAAAIBkSN4AAIDskLwBAACQJMkbAACQGWXshJY6kjcAAIAUkLwBAADZYc4bAAAASZK8AQAA2SF5AwAAIEmSNwAAIDNKkjcAAACSJHkDAACyQ/IGAABAkiRvAABAdhSTLqD5SN4AAABSQPMGAACQAh6bBAAAMsNWAQAAACRK8gYAAGSH5A0AAIAkSd4AAIDssFUAAAAASZK8AQAAmWG1SQAAABIleQMAALLDnDcAAACSJHkDAAAyw5w3AAAAEiV5AwAAssOcNwAAAJIkeQMAADKjJHkDAAAgSZI3AAAgOyRvAAAAJEnzBgAAkAIemwQAADLDgiUAAAAkSvIGAABkh+QNAACAJEneAACAzDDnDQAAgERJ3gAAgMyQvAEAAJAoyRsAAJAZkjcAAAASJXkDAACyo5RLuoJmo3kjszodcVHSJZAhb93y3aRLIGN8R1Fpe3XfNekSgGameQMAADLDnDcAAAASJXkDAAAyo1TM7pw3yRsAAEAKSN4AAIDMMOcNAACAREneAACAzChleJ83yRsAAEAKaN4AAABSwGOTAABAZliwBAAAgIq54IILIpfLxWmnndboayRvAABAZqRhk+558+bFr3/96+jfv39Z10neAAAAWsiaNWvimGOOiauuuiq6dOlS1rWaNwAAIDNKpZY7CoVC1NXVNTgKhcLH1jd+/PgYOXJkHHjggWX/bpo3AACAJqitrY3q6uoGR21t7Ue+/4YbbogFCxZ87Hs+jjlvAABAZrTknLeJEyfGhAkTGpzL5/Mbfe8rr7wSp556atxzzz2x1VZbNenzNG8AAABNkM/nP7JZ+2fz58+P5cuXx1577bX+3Lp162LOnDlxxRVXRKFQiLZt237sPTRvAABAZrTW1Sa/+MUvxuOPP97g3Lhx46Jv377x/e9/f5ONW4TmDQAAoNl16tQp9thjjwbnOnbsGN26ddvg/EfRvAEAAJlRKiVdQfPRvAEAACTgvvvuK+v9mjcAACAzWuuct0qwzxsAAEAKSN4AAIDMKJUkbwAAACRI8gYAAGRGqZh0Bc1H8gYAAJACmjcAAIAU8NgkAACQGUULlgAAAJAkyRsAAJAZtgoAAAAgUZI3AAAgM0pFyRsAAAAJkrwBAACZUSolXUHzkbwBAACkgOQNAADIDHPeAAAASJTkDQAAyIyifd4AAABIkuQNAADIjJLkDQAAgCRJ3gAAgMywzxsAAACJkrwBAACZYbVJAAAAEiV5AwAAMsNqk7QKJ54wJhY/NzfW1L0QDz5wewza+zNJl0SKGU9U0rJVa+Ks39wbX/jBNbHP966Mf7twRjz5yvKkyyLlfE9RKZ/dp3/87NrauHPBTfHIkvvjC8P3S7okaBLNW0ocddShcfFFk+K8H18Sg/YZHgsXPRV/vPO3UVPTLenSSCHjiUqqe6cQY39xa1S1bRNXHD8ybv7+0THhsM9H5w75pEsjxXxPUUlbbd0hnn9ycVx01mVJlwKbRfOWEqefenxcPfX6uPa638fTTz8fJ40/M955590YN/bopEsjhYwnKmna7D9Hr207xrn/Piz23LlnbN+tc3y+z46xY/fqpEsjxXxPUUkP/e/D8asLp8Z9M/+UdCm0gFKp5Y6WVlbzNnv27OjXr1/U1dVt8LPVq1fH7rvvHn/6k/9TVFq7du1ir736x
6zZ//hnWyqVYtbsB2Lw4IEJVkYaGU9U2v1Pvhj9duwRZ1x7dwz90bT46s9ujJseeirpskgx31MAG1dW83bZZZfF8ccfH507d97gZ9XV1fGtb30rLrnkkooVx9917941qqqqYvmylQ3OL1++Inr1rEmoKtLKeKLSXv1bXdz44JOxU/fqmPIfo+Koz+8eF97yQNw275mkSyOlfE8Bm6NYyrXY0dLKat4WLlwYw4cP/8ifH3zwwTF//vxN3qdQKERdXV2Do5TlrdABMqxYKkXfHbrHKSMHR98dauLf9u0XXx7cL/7woPQNACqprOZt2bJl0a5du4/8eVVVVaxYsWKT96mtrY3q6uoGR6n4VjmlbFFWrnwj6uvro0fP7g3O9+hRE0uXbfqfN3yY8USl1XTeOv6lZ5cG53bpuW28/uaahCoi7XxPAZujVMq12NHSymrett9++3jiiSc+8ueLFi2K3r17b/I+EydOjNWrVzc4cm06lVPKFmXt2rWxYMGiGDb0H8va5nK5GDZ0v5g7d9NJJ3yY8USlDfhEr3hx+aoG515asTp6d90mmYJIPd9TABtX1ibdX/rSl+KHP/xhDB8+PLbaaqsGP3v33Xdj0qRJMWrUqE3eJ5/PRz7fcAnpXC67m+lVwqWXXxXTpl4a8xcsinnz/hynnHx8dOzYIaZfOyPp0kgh44lKOvYLA2Lsz2+Jq++dHwcP2DWeeHlZ3DT3qfjhUV9IujRSzPcUldRh6w6xwy7br3+93Y69Y7fdd426VXWx7DV7UmZNEnPRWkpZzdsPfvCDuPnmm+NTn/pUfPvb344+ffpERMQzzzwTkydPjnXr1sV//ud/NkuhW7obb7wtarp3jbN/dEb06lUTCxc+GSNHHRvLl6/c9MXwT4wnKmmPnXrEJeMOiZ/f+XBc+T/zY/uuneK7hw2JkQM/lXRppJjvKSrp0wP6xK9uunz969PP+XZERNwx46449/QLkioLypYrlblSyEsvvRQnnnhi3H333esXGcnlcnHIIYfE5MmTY5dddmlSIVXtt9/0mwAS8tYt3026BDKm0xEXJV0CGbNX912TLoGMeWTJ/UmX0CRzt/tyi33W4CU3t9hnRZSZvEVE7LzzzvHHP/4x3nzzzVi8eHGUSqXYbbfdokuXLpu+GAAAgCYpu3n7QJcuXWLQoEGVrAUAAGCzZHnOW1mrTQIAAJCMJidvAAAArU0S+6+1FMkbAABACkjeAACAzCgmXUAzkrwBAACkgOQNAADIjFKY8wYAAECCNG8AAAAp4LFJAAAgM4qlpCtoPpI3AACAFJC8AQAAmVG0YAkAAABJkrwBAACZYasAAAAAEiV5AwAAMqOYdAHNSPIGAACQApI3AAAgM8x5AwAAIFGSNwAAIDPMeQMAACBRkjcAACAzJG8AAAAkSvIGAABkhtUmAQAASJTkDQAAyIxidoM3yRsAAEAaSN4AAIDMKJrzBgAAQFNNmTIl+vfvH507d47OnTvHvvvuG3fddVdZ99C8AQAANLMddtghLrjggpg/f348+uijMWzYsDjssMPiySefbPQ9PDYJAABkRinpAj7C6NGjG7w+//zzY8qUKTF37tzYfffdG3UPzRsAAEATFAqFKBQKDc7l8/nI5/Mfe926devixhtvjLfffjv23XffRn+exyYBAIDMKLbgUVtbG9XV1Q2O2traj6zt8ccfj2222Sby+XyccMIJccstt0S/fv0a/btJ3gAAAJpg4sSJMWHChAbnPi5169OnTzz22GOxevXq+MMf/hBjxoyJ+++/v9ENnOYNAADIjGKu5bYKaMwjkh/Wvn372HXXXSMiYuDAgTFv3ry4/PLL49e//nWjrvfYJAAAQAKKxeIGc+Y+juQNAADIjNa62uTEiRNjxIgRsdNOO8Vbb70V119/fdx3331x9913N/oemjcAAIBmtnz58jjuuOPi9ddfj+rq6ujfv3/cfffdcdBBBzX6Hpo3AAAgM4pJF/ARpk6dutn3M
OcNAAAgBSRvAABAZhRbbrHJFid5AwAASAHJGwAAkBnFyG70JnkDAABIAckbAACQGa11n7dKkLwBAACkgOQNAADIDKtNAgAAkKhWk7ydvN3+SZdAxvxiyZ+SLoEM6XTERUmXQMa86zuKChvSf1zSJQDNrNU0bwAAAJurmHQBzchjkwAAACkgeQMAADLDVgEAAAAkSvIGAABkhq0CAAAASJTkDQAAyAyrTQIAAJAoyRsAAJAZkjcAAAASJXkDAAAyo2S1SQAAAJIkeQMAADLDnDcAAAASJXkDAAAyQ/IGAABAoiRvAABAZpSSLqAZSd4AAABSQPIGAABkRtE+bwAAACRJ8wYAAJACHpsEAAAyw1YBAAAAJEryBgAAZIbkDQAAgERJ3gAAgMywSTcAAACJkrwBAACZYZNuAAAAEiV5AwAAMsNqkwAAACRK8gYAAGSG1SYBAABIlOQNAADIjGKGszfJGwAAQApI3gAAgMyw2iQAAACJkrwBAACZkd0Zb5I3AACAVNC8AQAApIDHJgEAgMywYAkAAACJkrwBAACZUcwlXUHzkbwBAACkgOQNAADIjGKGNwuQvAEAAKSA5A0AAMiM7OZukrdU+OJJh8Xp/31+1D4xLc599Nfx9Su/EzWf7J10WaTciSeMicXPzY01dS/Egw/cHoP2/kzSJZFyxhSVcvCRY2KPISM2OH78s8lJl0ZKfXaf/vGza2vjzgU3xSNL7o8vDN8v6ZKgSTRvKfAv+3w6Hviv/4nLj/hh/Opr50fbqrZxwnVnRfsO+aRLI6WOOurQuPiiSXHejy+JQfsMj4WLnoo/3vnbqKnplnRppJQxRSXdcPXlcd9tv11/XHXZTyIi4uCh+ydcGWm11dYd4vknF8dFZ12WdCm0gGILHi1N85YCV465IOb94f5Y+vyrseTpl+P6M6ZE1x1qYoc9d0m6NFLq9FOPj6unXh/XXvf7ePrp5+Ok8WfGO++8G+PGHp10aaSUMUUlde2ybXTv1nX9cf//fTh23L53DPrsnkmXRko99L8Px68unBr3zfxT0qXAZim7eSsWi3HNNdfEqFGjYo899og999wzDj300LjuuuuiVMryE6atR4dOW0dExDur1iRcCWnUrl272Guv/jFr9j/+gJVKpZg1+4EYPHhggpWRVsYUzWnt2rVxx//8bxwx8uDI5TK8eRNQMcUotdjR0spq3kqlUhx66KHxzW9+M1577bXYc889Y/fdd4+XXnopxo4dG0cccURz1cn/l8vl4vAfjYm/zHsmlj73atLlkELdu3eNqqqqWL5sZYPzy5eviF49axKqijQzpmhOs+Y8FG+tWROHf+mgpEsBSFxZq01Onz495syZE7NmzYqhQ4c2+Nns2bPj8MMPj+uuuy6OO+64j71PoVCIQqHQ4Fx9aV1U5dqWU84W6cjzvh69++wYP/+3SUmXAgDN7uY77o79Bu8dPcyfBBopy88ClpW8/e53v4uzzjprg8YtImLYsGFx5plnxm9/+9tN3qe2tjaqq6sbHPNWP11OKVukL58zLvoN2ysmH31urF76RtLlkFIrV74R9fX10aNn9wbne/SoiaXLViRUFWlmTNFclixdFnMffSyOHD086VIAWoWymrdFixbF8OEf/QU6YsSIWLhw4SbvM3HixFi9enWDY1D1p8spZYvz5XPGxZ6HDIpf/p/z4o1X/csQTbd27dpYsGBRDBv6j2WSc7lcDBu6X8ydOz/BykgrY4rmcsud90TXLtXxr/t+LulSgBTJ8mqTZT02+cYbb0TPnj0/8uc9e/aMN998c5P3yefzkc83XObeI5Mf7cjzvh4DDxsSU4+/OApvvxudaqojIuK9undibWFtwtWRRpdeflVMm3ppzF+wKObN+3OccvLx0bFjh5h+7YykSyOljCkqrVgsxq133hOHjTgwqqr8OwKbp8PWHWKHXbZf/3q7HXvHbrvvGnWr6mLZa8sTrAzKU1bztm7duqiq+uhL2rZtG/X19
ZtdFA3t97WDIyLi2zMaznO7/owpMe8P9ydREil34423RU33rnH2j86IXr1qYuHCJ2PkqGNj+fKVm74YNsKYotIemvfneH3Z8jhi5MFJl0IGfHpAn/jVTZevf336Od+OiIg7ZtwV555+QVJl0UySWAWyMWpra+Pmm2+OZ555Jjp06BCf//zn46c//Wn06dOn0ffIlcpY379NmzYxYsSIDVKzDxQKhZg5c2asW7eu0QV84PRP2AuIyvrFEnu5AK3Xu76jqLAh/cclXQIZ88iSdIYEE1qwr7jkxRsa/d7hw4fH0UcfHYMGDYr6+vo466yz4oknnoinnnoqOnbs2Kh7lJW8jRkzZpPv2dRKkwAAAM2ldeZuETNnzmzwevr06dGjR4+YP39+/Ou//muj7lFW8zZt2rRy3g4AAJBZG9sCbWPre2zM6tWrIyKia9eujf68slabBAAA4O82tgVabW3tJq8rFotx2mmnxZAhQ2KPPfZo9OeVlbwBAAC0Zi25hP/EiRNjwoQJDc41JnUbP358PPHEE/HAAw+U9XmaNwAAgCZo7COSH/btb3877rjjjpgzZ07ssMMOZV2reQMAADKj1EqXLCmVSnHyySfHLbfcEvfdd1/ssssuZd9D8wYAANDMxo8fH9dff33893//d3Tq1CmWLl0aERHV1dXRoUOHRt1D8wYAAGRGS855K8eUKVMiIuKAAw5ocH7atGkxduzYRt1D8wYAANDMSqXNf5xT8wYAAGRGsZXOeasE+7wBAACkgOQNAADIjOzmbpI3AACAVJC8AQAAmWHOGwAAAImSvAEAAJnRWvd5qwTJGwAAQApI3gAAgMwomfMGAABAkiRvAABAZpjzBgAAQKI0bwAAACngsUkAACAzLFgCAABAoiRvAABAZliwBAAAgERJ3gAAgMwolsx5AwAAIEGSNwAAIDOym7tJ3gAAAFJB8gYAAGRGMcPZm+QNAAAgBSRvAABAZpQkbwAAACRJ8gYAAGRGMekCmpHkDQAAIAUkbwAAQGZYbRIAAIBESd4AAIDMsNokAAAAiZK8AQAAmWG1SQAAABKleQMAAEgBj00CAACZUSpZsAQAAIAESd4AAIDMsEk3AAAAiZK8AQAAmZHlrQJaTfP2iyV/SroEAGgxHbbbP+kSyJi/9O+bdAlAM2s1zRsAAMDmKpnzBgAAQJIkbwAAQGZYbRIAAIBESd4AAIDMKJUkbwAAACRI8gYAAGRGlvd5k7wBAACkgOQNAADIDPu8AQAAkCjJGwAAkBn2eQMAACBRmjcAAIAU8NgkAACQGTbpBgAAIFGSNwAAIDMsWAIAAECiJG8AAEBm2KQbAACAREneAACAzChabRIAAIAkSd4AAIDMyG7uJnkDAABIBckbAACQGfZ5AwAAIFGSNwAAIDMkbwAAADTZnDlzYvTo0bHddttFLpeLW2+9tex7aN4AAIDMKJVKLXaU4+23344BAwbE5MmTm/y7eWwSAACgmY0YMSJGjBixWffQvAEAAJnRknPeCoVCFAqFBufy+Xzk8/lm+TyPTQIAADRBbW1tVFdXNzhqa2ub7fMkbwAAQGaUWjB5mzhxYkyYMKHBueZK3SI0bwAAAE3SnI9IbozHJgEAAFJA8gYAAGRGuUv4t5Q1a9bE4sWL17/+61//Go899lh07do1dtppp0bdQ/MGAADQzB599NEYOnTo+tcfzJUbM2ZMTJ8+vVH30LwBAACZ0ZJbBZTjgAMO2OxU0Jw3AACAFJC8AQAAmdFa57xVguQNAAAgBSRvAABAZrTWOW+VIHkDAABIAckbAACQGSXJGwAAAEmSvAEAAJlRtNokAAAASZK8AQAAmWHOGwAAAInSvKXIiSeMicXPzY01dS/Egw/cHoP2/kzSJZFixhOVZkxRacYUzaXzmKNjp0dnxbYTTkq6FJpBsVRqsaOlad5S4qijDo2LL5oU5/34khi0z/BYuOip+OOdv42amm5Jl0YKGU9UmjFFpRlTNJf2/frEN
l8eFe8/90LSpUDZNG8pcfqpx8fVU6+Pa6/7fTz99PNx0vgz45133o1xY49OujRSyHii0owpKs2YojnkOmwV3c47K/52/iVRfOutpMuhmZRa8H8tTfOWAu3atYu99uofs2b/af25UqkUs2Y/EIMHD0ywMtLIeKLSjCkqzZiiuXT5/qnx7v+dG4VHFiRdCjRJWc3bl770pVi9evX61xdccEGsWrVq/eu//e1v0a9fv03ep1AoRF1dXYOjlOH9GDZX9+5do6qqKpYvW9ng/PLlK6JXz5qEqiKtjCcqzZii0owpmsPWBw+N9n13jVVXXJ10KdBkZTVvd999dxQKhfWvf/KTn8Qbb7yx/nV9fX08++yzm7xPbW1tVFdXNzhKRdE1AACV17ZnTXT5zvj42w9qI95fm3Q5NLMsL1hS1j5v/5yONTUtmzhxYkyYMKHBuS7d+jbpXluClSvfiPr6+ujRs3uD8z161MTSZSsSqoq0Mp6oNGOKSjOmqLT2fT8Vbbt1iV6/+dX6c7mqtpH/bP/o9JXD45XPD48oFhOsEBonkTlv+Xw+Onfu3ODI5XJJlJIKa9eujQULFsWwofutP5fL5WLY0P1i7tz5CVZGGhlPVJoxRaUZU1Tae/MWxOtf/UYsPeY/1h+FJ5+Jd2bOiqXH/IfGLWOyvGBJWclbLpfboMnSdLWMSy+/KqZNvTTmL1gU8+b9OU45+fjo2LFDTL92RtKlkULGE5VmTFFpxhSVVHrn3Vj7wosNz733XqxbVbfBeWjNyn5scuzYsZHP5yMi4r333osTTjghOnbsGBHRYD4clXXjjbdFTfeucfaPzohevWpi4cInY+SoY2P58pWbvhj+ifFEpRlTVJoxBTRVEnPRWkquVMbEtXHjxjXqfdOmTSu7kKr225d9DQAAf/eX/tYPoLJ2enRW0iU0yb9036vFPuuFlS277URZyVtTmjIAAICWksRctJZik24AAIAUKCt5AwAAaM1KpeyuHip5AwAASAHJGwAAkBlFc94AAABIkuQNAADIjDJ2QksdyRsAAEAKSN4AAIDMMOcNAACAREneAACAzDDnDQAAgERJ3gAAgMwoSt4AAABIkuYNAAAgBTw2CQAAZEbJVgEAAAAkSfIGAABkhq0CAAAASJTkDQAAyIyiOW8AAAAkSfIGAABkhjlvAAAAJEryBgAAZEZR8gYAAECSJG8AAEBmmPMGAABAoiRvAABAZtjnDQAAgERJ3gAAgMww5w0AAIBESd4AAIDMsM8bAAAAidK8AQAApIDHJgEAgMwo2SoAAACAJEneAACAzLBgCQAAAImSvAEAAJlhk24AAAASJXkDAAAyw2qTAAAAJEryBgAAZIY5bwAAACRK8wYAAGRGqVRqsaMpJk+eHJ/4xCdiq622in322SceeeSRRl+reQMAAGgBM2bMiAkTJsSkSZNiwYIFMWDAgDjkkENi+fLljbpe8wYAAGRGqQWPcl1yySVx/PHHx7hx46Jfv37xq1/9Krbeeuu45pprGnW95g0AAKAJCoVC1NXVNTgKhcJG3/v+++/H/Pnz48ADD1x/rk2bNnHggQfGQw891KjPazWrTda//1rSJbR6hUIhamtrY+LEiZHP55Muhwwwpqg0Y4pKM6aoJONpy9CSfcXZZ58d55xzToNzkyZNirPPPnuD965cuTLWrVsXPXv2bHC+Z8+e8cwzzzTq83KlLK+lmTF1dXVRXV0dq1evjs6dOyddDhlgTFFpxhSVZkxRScYTlVYoFDZI2vL5/Eb/48CSJUti++23jwcffDD23Xff9ee/973vxf333x8PP/zwJj+v1SRvAAAAafJRjdrGdO/ePdq2bRvLli1rcH7ZsmXRq1evRt3DnDcAAIBm1r59+xg4cGDMmjVr/blisRizZs1qkMR9HMkbAABAC5gwYUKMGTMm9t577/jc5z4Xl112Wbz99tsxbty4Rl2veUuRfD4fkyZNMsGWijGmqDRjikozpqgk44mkffWrX
40VK1bEj370o1i6dGl85jOfiZkzZ26wiMlHsWAJAABACpjzBgAAkAKaNwAAgBTQvAEAAKSA5g0AACAFNG8p8dBDD0Xbtm1j5MiRSZdCyo0dOzZyudz6o1u3bjF8+PBYtGhR0qWRckuXLo2TTz45PvnJT0Y+n48dd9wxRo8e3WA/G2iMD39PtWvXLnr27BkHHXRQXHPNNVEsFpMujxT65799HxzDhw9PujQoi+YtJaZOnRonn3xyzJkzJ5YsWZJ0OaTc8OHD4/XXX4/XX389Zs2aFVVVVTFq1KikyyLFXnzxxRg4cGDMnj07Lrroonj88cdj5syZMXTo0Bg/fnzS5ZFCH3xPvfjii3HXXXfF0KFD49RTT41Ro0ZFfX190uWRQh/+2/fB8bvf/S7psqAs9nlLgTVr1sSMGTPi0UcfjaVLl8b06dPjrLPOSrosUiyfz0evXr0iIqJXr15x5plnxv777x8rVqyImpqahKsjjU466aTI5XLxyCOPRMeOHdef33333ePrX/96gpWRVh/+ntp+++1jr732isGDB8cXv/jFmD59enzzm99MuELS5sNjCtJK8pYCv//976Nv377Rp0+fOPbYY+Oaa64J2/NRKWvWrInf/OY3seuuu0a3bt2SLocUeuONN2LmzJkxfvz4Bo3bB7bddtuWL4pMGjZsWAwYMCBuvvnmpEsBSITmLQWmTp0axx57bET8PfJfvXp13H///QlXRZrdcccdsc0228Q222wTnTp1ittuuy1mzJgRbdr4SqB8ixcvjlKpFH379k26FLYAffv2jRdffDHpMkihD//t++D4yU9+knRZUBaPTbZyzz77bDzyyCNxyy23REREVVVVfPWrX42pU6fGAQcckGxxpNbQoUNjypQpERHx5ptvxi9/+csYMWJEPPLII7HzzjsnXB1p40kAWlKpVIpcLpd0GaTQh//2faBr164JVQNNo3lr5aZOnRr19fWx3XbbrT9XKpUin8/HFVdcEdXV1QlWR1p17Ngxdt111/Wvr7766qiuro6rrroqfvzjHydYGWm02267RS6Xi2eeeSbpUtgCPP3007HLLrskXQYp9M9/+yCNPCPVitXX18d1110XP/vZz+Kxxx5bfyxcuDC22247KyRRMblcLtq0aRPvvvtu0qWQQl27do1DDjkkJk+eHG+//fYGP1+1alXLF0UmzZ49Ox5//PE48sgjky4FIBGSt1bsjjvuiDfffDO+8Y1vbJCwHXnkkTF16tQ44YQTEqqONCsUCrF06dKI+Ptjk1dccUWsWbMmRo8enXBlpNXkyZNjyJAh8bnPfS7OPffc6N+/f9TX18c999wTU6ZMiaeffjrpEkmZD76n1q1bF8uWLYuZM2dGbW1tjBo1Ko477rikyyOFPvy37wNVVVXRvXv3hCqC8mneWrGpU6fGgQceuNFHI4888si48MILY9GiRdG/f/8EqiPNZs6cGb17946IiE6dOkXfvn3jxhtvNI+SJvvkJz8ZCxYsiPPPPz++853vxOuvvx41NTUxcODADeaYQGN88D1VVVUVXbp0iQEDBsTPf/7zGDNmjMWVaJIP/+37QJ8+fTzyTarkSmaaAwAAtHr+0xUAAEAKaN4AAABSQPMGAACQApo3AACAFNC8AQAApIDmDQAAIAU0bwAAACmgeQMAAEgBzRsAAEAKaN4AAABSQPMGAACQApo3AACAFPh/JjvFdS+oC5gAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "y_pred = []\n", + "y_true = []\n", + "\n", + "for inputs, label in test_set:\n", + " output = spoter(inputs) # Feed Network\n", + "\n", + " output = torch.argmax(output, dim=2) # Get Prediction\n", + " output = output.detach().numpy().tolist()[0] # Convert to list\n", + " \n", + " y_true.extend(output) # Save Truth\n", + "\n", + " y_pred.append(label) # Save Prediction\n", + "\n", + "# constant for classes\n", + "classes = ('A', 'B', 'C', 'D', 'E')\n", + "\n", + "# Build confusion matrix\n", + "cf_matrix = confusion_matrix(y_true, y_pred)\n", + "df_cm = pd.DataFrame(cf_matrix, index = [i for i in classes],\n", + " columns = [i for i in classes])\n", + "plt.figure(figsize = (12,7))\n", + "sn.heatmap(df_cm, annot=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/spoter_40.pth b/models/spoter_56.pth similarity index 70% rename from models/spoter_40.pth rename to models/spoter_56.pth index 1e15c13..1b52a7e 100644 Binary files a/models/spoter_40.pth and b/models/spoter_56.pth differ diff --git a/src/datasets/finger_spelling_dataset.py b/src/datasets/finger_spelling_dataset.py index e412864..ec9acb8 100644 --- a/src/datasets/finger_spelling_dataset.py +++ b/src/datasets/finger_spelling_dataset.py @@ -4,8 +4,8 @@ import numpy as np import torch from sklearn.model_selection import train_test_split -from identifiers import LANDMARKS -from keypoint_extractor 
import KeypointExtractor +from src.identifiers import LANDMARKS +from src.keypoint_extractor import KeypointExtractor class FingerSpellingDataset(torch.utils.data.Dataset): diff --git a/src/datasets/wlasl_dataset.py b/src/datasets/wlasl_dataset.py index 4f8d35f..2c17919 100644 --- a/src/datasets/wlasl_dataset.py +++ b/src/datasets/wlasl_dataset.py @@ -4,8 +4,8 @@ from collections import OrderedDict import numpy as np import torch -from identifiers import LANDMARKS -from keypoint_extractor import KeypointExtractor +from src.identifiers import LANDMARKS +from src.keypoint_extractor import KeypointExtractor class WLASLDataset(torch.utils.data.Dataset): diff --git a/src/keypoint_extractor.py b/src/keypoint_extractor.py index 43fe833..2c8a3e7 100644 --- a/src/keypoint_extractor.py +++ b/src/keypoint_extractor.py @@ -151,25 +151,34 @@ class KeypointExtractor: return results - def normalize_hands(self, dataframe: pd.DataFrame) -> pd.DataFrame: + def normalize_hands(self, dataframe: pd.DataFrame, norm_algorithm: str="minmax") -> pd.DataFrame: """normalize_hand this function normalizes the hand keypoints of a dataframe :param dataframe: the dataframe to normalize :type dataframe: pd.DataFrame + :param norm_algorithm: the normalization algorithm to use, pick from "minmax" and "bohacek" + :type norm_algorithm: str :return: the normalized dataframe :rtype: pd.DataFrame """ - - # normalize left hand - dataframe = self.normalize_hand_helper(dataframe, "left_hand") - # normalize right hand - dataframe = self.normalize_hand_helper(dataframe, "right_hand") + if norm_algorithm == "minmax": + # normalize left hand + dataframe = self.normalize_hand_minmax(dataframe, "left_hand") + # normalize right hand + dataframe = self.normalize_hand_minmax(dataframe, "right_hand") + elif norm_algorithm == "bohacek": + # normalize left hand + dataframe = self.normalize_hand_bohacek(dataframe, "left_hand") + # normalize right hand + dataframe = self.normalize_hand_bohacek(dataframe, "right_hand") + 
else: + return dataframe return dataframe - def normalize_hand_helper(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame: - """normalize_hand_helper this function normalizes the hand keypoints of a dataframe + def normalize_hand_minmax(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame: + """normalize_hand_helper this function normalizes the hand keypoints of a dataframe with respect to the minimum and maximum coordinates :param dataframe: the dataframe to normalize :type dataframe: pd.DataFrame @@ -194,9 +203,66 @@ class KeypointExtractor: # calculate the width and height of the bounding box around the hand keypoints bbox_width, bbox_height = max_x - min_x, max_y - min_y + # repeat the center coordinates and bounding box dimensions to match the shape of hand_coords (numpy magic) + center_x, center_y = center_x.reshape(-1, 1, 1), center_y.reshape(-1, 1, 1) + center_coords = np.concatenate((np.tile(center_x, (1, 21, 1)), np.tile(center_y, (1, 21, 1))), axis=2) + + bbox_width, bbox_height = bbox_width.reshape(-1, 1, 1), bbox_height.reshape(-1, 1 ,1) + bbox_dims = np.concatenate((np.tile(bbox_width, (1, 21, 1)), np.tile(bbox_height, (1, 21, 1))), axis=2) + + if np.any(bbox_dims == 0): + return dataframe + # normalize the hand keypoints based on the bounding box around the hand + norm_hand_coords = (hand_coords - center_coords) / bbox_dims + + # flatten the normalized hand keypoints array and replace the original hand keypoints with the normalized hand keypoints in the dataframe + dataframe.iloc[:, hand_columns] = norm_hand_coords.reshape(-1, 42) + + return dataframe + + def normalize_hand_bohacek(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame: + """normalize_hand_helper this function normalizes the hand keypoints of a dataframe using the bohacek normalization algorithm + + :param dataframe: the dataframe to normalize + :type dataframe: pd.DataFrame + :param hand: the hand to normalize + :type hand: str + :return: the normalized dataframe + 
:rtype: pd.DataFrame + """ + # get all columns that belong to the hand (left hand column 66 - 107, right hand column 108 - 149) + hand_columns = np.array([i for i in range(66 + (42 if hand == "right_hand" else 0), 108 + (42 if hand == "right_hand" else 0))]) + + # get the x, y coordinates of the hand keypoints + hand_coords = dataframe.iloc[:, hand_columns].values.reshape(-1, 21, 2) + + # get the min and max x, y coordinates of the hand keypoints + min_x, min_y = np.min(hand_coords[:, :, 0], axis=1), np.min(hand_coords[:, :, 1], axis=1) + max_x, max_y = np.max(hand_coords[:, :, 0], axis=1), np.max(hand_coords[:, :, 1], axis=1) + + # calculate the deltas + width, height = max_x - min_x, max_y - min_y + if width > height: + delta_x = 0.1 * width + delta_y = delta_x + ((width - height) / 2) + else: + delta_y = 0.1 * height + delta_x = delta_y + ((height - width) / 2) + + # Set the starting and ending point of the normalization bounding box + starting_x, starting_y = min_x - delta_x, min_y - delta_y + ending_x, ending_y = max_x + delta_x, max_y + delta_y + + # calculate the center of the bounding box and the bounding box dimensions + bbox_center_x, bbox_center_y = (starting_x + ending_x) / 2, (starting_y + ending_y) / 2 + bbox_width, bbox_height = starting_x - ending_x, starting_y - ending_y + # repeat the center coordinates and bounding box dimensions to match the shape of hand_coords - center_coords = np.tile(np.array([center_x, center_y]), (21, 1)).reshape(-1, 21, 2) - bbox_dims = np.tile(np.array([bbox_width, bbox_height]), (21, 1)).reshape(-1, 21, 2) + center_x, center_y = center_x.reshape(-1, 1, 1), center_y.reshape(-1, 1, 1) + center_coords = np.concatenate((np.tile(bbox_center_x, (1, 21, 1)), np.tile(bbox_center_y, (1, 21, 1))), axis=2) + + bbox_width, bbox_height = bbox_width.reshape(-1, 1, 1), bbox_height.reshape(-1, 1 ,1) + bbox_dims = np.concatenate((np.tile(bbox_width, (1, 21, 1)), np.tile(bbox_height, (1, 21, 1))), axis=2) if np.any(bbox_dims == 0): 
return dataframe diff --git a/src/train.py b/src/train.py index de50b9a..e27bd7d 100644 --- a/src/train.py +++ b/src/train.py @@ -13,12 +13,12 @@ import torch.optim as optim from torch.utils.data import DataLoader from torchvision import transforms -from augmentations import MirrorKeypoints -from datasets.finger_spelling_dataset import FingerSpellingDataset -from datasets.wlasl_dataset import WLASLDataset -from identifiers import LANDMARKS -from keypoint_extractor import KeypointExtractor -from model import SPOTER +from src.augmentations import MirrorKeypoints +from src.datasets.finger_spelling_dataset import FingerSpellingDataset +from src.datasets.wlasl_dataset import WLASLDataset +from src.identifiers import LANDMARKS +from src.keypoint_extractor import KeypointExtractor +from src.model import SPOTER def train(): @@ -81,10 +81,7 @@ def train(): if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]): pred_correct += 1 pred_all += 1 - - # if i % 100 == 0: - # print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}") - + if scheduler: scheduler.step(running_loss.item() / len(train_loader)) @@ -107,7 +104,7 @@ def train(): # save checkpoint - if val_acc > top_val_acc: + if val_acc > top_val_acc and epoch > 55: top_val_acc = val_acc top_train_acc = train_acc checkpoint_index = epoch diff --git a/visualize_data.ipynb b/visualize_data.ipynb index 3667993..d4ea50c 100644 --- a/visualize_data.ipynb +++ b/visualize_data.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -14,16 +14,16 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "video_name = 'C!3_20230225181728393157_8PWYR.mp4'" + "video_name = 'A_robbe.mp4' " ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ @@ -33,36 +33,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "drawing\n" - ] - }, - { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import numpy as np\n", "from IPython.display import HTML\n", @@ -84,20 +57,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from src.model import SPOTER\n", "from src.identifiers import LANDMARKS\n", @@ -112,12 +74,106 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# get average number of frames in test set\n", + "from src.keypoint_extractor import KeypointExtractor\n", + "from src.datasets.finger_spelling_dataset import FingerSpellingDataset\n", + "from src.identifiers import LANDMARKS\n", + "import numpy as np\n", + "\n", + "keypoints_extractor = KeypointExtractor(\"data/fingerspelling/data/\")\n", + "test_set = FingerSpellingDataset(\"data/fingerspelling/data/\", keypoints_extractor, keypoints_identifier=LANDMARKS, subset=\"val\")\n", + "\n", + "frames = []\n", + "labels = []\n", + "for sample, label in test_set:\n", + " frames.append(sample.shape[0])\n", + " labels.append(label)\n", + "\n", + "print(np.mean(frames))\n", + "# get label frequency in the labels list\n", + "from collections import Counter\n", + "\n", + "counter = Counter(labels)\n", + "print(counter)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hand keypoint visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "def plot_hand_keypoints(dataframe, hand, frame):\n", + " hand_columns = np.array([i for i in range(66 + (42 if hand == \"right\" else 0), 108 + (42 if hand == \"right\" else 0))])\n", + " \n", + " # get the x, y coordinates of the hand 
keypoints\n", + " frame_df = dataframe.iloc[frame:frame+1, hand_columns]\n", + " hand_coords = frame_df.values.reshape(21, 2)\n", + " \n", + " x_coords = hand_coords[:, ::2] #Even indices\n", + " y_coords = hand_coords[:, 1::2] #Uneven indices\n", + " \n", + " #Plot the keypoints\n", + " plt.scatter(x_coords, y_coords)\n", + " return frame_df.style" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Set video, hand and frame to display\n", + "video_name = 'A_victor.mp4'\n", + "hand = \"right\"\n", + "frame = 1\n", + "%reload_ext autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from src.keypoint_extractor import KeypointExtractor\n", + "import numpy as np\n", + "\n", + "#Extract keypoints from requested video\n", + "keypoints_extractor = KeypointExtractor(\"data/fingerspelling/data/\")\n", + "\n", + "\n", + "#Plot the hand keypoints\n", + "df = keypoints_extractor.extract_keypoints_from_video(video_name, normalize=False)\n", + "plot_hand_keypoints(df, hand, frame)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Plot the NORMALIZED hand keypoints\n", + "df = keypoints_extractor.extract_keypoints_from_video(video_name, normalize=True)\n", + "plot_hand_keypoints(df, hand, frame)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -131,9 +187,8 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.9.13" }, - "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" diff --git a/webcam_view.py b/webcam_view.py new file mode 100644 index 0000000..f043c2c --- /dev/null +++ b/webcam_view.py @@ -0,0 +1,129 @@ +import cv2 +import 
"""Live webcam demo: fingerspelling recognition with MediaPipe + SPOTER.

Reads frames from the default webcam, extracts pose and hand landmarks with
MediaPipe Holistic, normalizes each hand's keypoints into its bounding box,
and classifies a sliding window of the last WINDOW_SIZE frames with a trained
SPOTER model.  The top-3 letter predictions are overlaid on the video feed.
Press Esc to quit.
"""
from collections import deque

import cv2
import mediapipe as mp
import numpy as np
import torch

from src.identifiers import LANDMARKS
from src.model import SPOTER

# Number of consecutive frames the model consumes as one sequence.
WINDOW_SIZE = 8
# Class index -> letter produced by the model head (5 classes).
CLASS_LABELS = {0: "A", 1: "B", 2: "C", 3: "D", 4: "E"}


def extract_keypoints(landmarks):
    """Flatten a MediaPipe landmark list to [x0, y0, x1, y1, ...].

    Returns None when the landmarks were not detected in this frame.
    """
    if landmarks:
        return [coord for lm in landmarks.landmark for coord in (lm.x, lm.y)]
    return None


def normalize_hand(frame, data, hand):
    """Normalize one hand's 21 keypoints to its bounding box, centered at 0.

    Columns 66..107 of `data` hold the left hand and 108..149 the right hand
    as x/y pairs (pose occupies 0..65).  Draws the hand's bounding box on
    `frame` as a side effect.  Returns (data, frame); `data` is returned
    unchanged when the bounding box is degenerate (hand absent / all zeros).
    """
    offset = 42 if hand == "right_hand" else 0
    hand_columns = np.arange(66 + offset, 108 + offset)
    hand_data = np.array(data[0])[hand_columns].reshape(21, 2)

    min_x, min_y = hand_data[:, 0].min(), hand_data[:, 1].min()
    max_x, max_y = hand_data[:, 0].max(), hand_data[:, 1].max()
    bbox_width, bbox_height = max_x - min_x, max_y - min_y
    # Guard against division by zero when the hand was not detected.
    if bbox_width == 0 or bbox_height == 0:
        return data, frame

    center = np.array([(min_x + max_x) / 2, (min_y + max_y) / 2])
    hand_data = (hand_data - center) / np.array([bbox_width, bbox_height])

    # Draw the bounding box; landmark coordinates are normalized to [0, 1].
    top_left = (int(min_x * frame.shape[1]), int(min_y * frame.shape[0]))
    bottom_right = (int(max_x * frame.shape[1]), int(max_y * frame.shape[0]))
    frame = cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 2)

    data[:, hand_columns] = hand_data.reshape(-1, 42)
    return data, frame


def main():
    """Run the webcam capture / prediction loop until Esc is pressed."""
    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    # Initialize the MediaPipe Holistic model (pose + both hands).
    holistic = mp_holistic.Holistic(
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
        model_complexity=2,
    )

    spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) * 2)
    # Inference below runs on CPU tensors, so load the weights onto CPU.
    spoter_model.load_state_dict(
        torch.load("models/spoter_56.pth", map_location="cpu")
    )
    spoter_model.eval()  # inference mode: freeze dropout / norm statistics

    # x and y column indices of the landmarks the model was trained on.
    landmark_columns = np.array(
        [c for i in LANDMARKS.values() for c in (i * 2, i * 2 + 1)]
    )

    # Sliding window of the most recent per-frame keypoint vectors;
    # deque(maxlen=...) replaces the manual pop(0) bookkeeping.
    keypoints = deque(maxlen=WINDOW_SIZE)

    cap = cv2.VideoCapture(0)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # MediaPipe expects RGB input.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = holistic.process(frame)

            k_pose = extract_keypoints(results.pose_landmarks)
            k_left = extract_keypoints(results.left_hand_landmarks)
            k_right = extract_keypoints(results.right_hand_landmarks)

            # Need the pose plus at least one hand to form a sample;
            # a missing hand is zero-filled (42 values = 21 x/y pairs).
            if k_pose and (k_left or k_right):
                data = np.array(
                    [k_pose + (k_left or [0] * 42) + (k_right or [0] * 42)]
                )
                data, frame = normalize_hand(frame, data, "left_hand")
                data, frame = normalize_hand(frame, data, "right_hand")

                keypoints.append(np.array(data[0])[landmark_columns])

                if len(keypoints) == WINDOW_SIZE:
                    # Stack to one ndarray first: torch.tensor on a list of
                    # arrays is a slow per-element copy.
                    window = torch.tensor(np.array(keypoints)).float()
                    with torch.no_grad():  # no gradients needed at inference
                        # NOTE(review): expand(1, -1, -1) presumably adds the
                        # batch dimension SPOTER expects — confirm against model.
                        outputs = spoter_model(window).expand(1, -1, -1)
                    probabilities = torch.nn.functional.softmax(outputs, dim=2)
                    topk = torch.topk(probabilities, k=3, dim=2)

                    # Overlay the top-3 letters and their confidence scores
                    # at the top right of the frame.
                    for i, (label, score) in enumerate(
                        zip(topk.indices[0][0], topk.values[0][0])
                    ):
                        cv2.putText(
                            frame,
                            f"{CLASS_LABELS[label.item()]} {score.item():.2f}",
                            (frame.shape[1] - 200, 50 + i * 50),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1,
                            (0, 255, 0),
                            2,
                        )

            mp_drawing.draw_landmarks(
                frame, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS
            )
            mp_drawing.draw_landmarks(
                frame, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS
            )
            mp_drawing.draw_landmarks(
                frame, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS
            )

            # cv2.imshow expects BGR; convert back so colors display correctly
            # (the original displayed the RGB frame, swapping red and blue).
            cv2.imshow("MediaPipe Hands", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

            # Exit on Esc.
            if cv2.waitKey(5) & 0xFF == 27:
                break
    finally:
        # Release camera, windows, and the MediaPipe graph even on error.
        cap.release()
        cv2.destroyAllWindows()
        holistic.close()


if __name__ == "__main__":
    main()