Draft: Multiple prediction methods #2

Open
tihabils wants to merge 15 commits from multiple_prediction_methods into main
30 changed files with 3577 additions and 484 deletions

1
.gitignore vendored
View File

@ -155,3 +155,4 @@ out-img/
converted_models/
*.pth
*.onnx
.devcontainer

20
README2.md Normal file
View File

@ -0,0 +1,20 @@
# Spoter Embeddings
## Creating dataset
First, make a folder where all your videos are located. When this is done, all keypoints can be extracted from the videos using the following command. This will extract the keypoints and store them in \<path-to-landmarks-folder\>.
```
python3 preprocessing.py extract --videos-folder <path-to-videos-folder> --output-landmark <path-to-landmarks-folder>
```
When this is done, the dataset can be created using the following command:
```
python3 preprocessing.py create --landmarks-dataset <path-to-landmarks-folder> --videos-folder <path-to-videos-folder> --dataset-folder <dataset-output-folder> (--create-new-split --test-size <test-percentage>)
```
The above command generates a train (and val) csv file which includes all the extracted keypoints. These can then be used to train models or generate embeddings.
## Creating Embeddings
The embeddings can be created using the following command:
```
python3 export_embeddings.py --checkpoint <path-to-checkpoints-file> --dataset <path-to-dataset-file> --output <embeddings-output-file>
```
The command above generates the embeddings for a given dataset and saves them as a csv file.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,91 @@
embeddings,label_name,labels,embeddings2
"[[ 0.01295076 -0.1074132 0.5111911 0.08452003 -0.7638924 0.6458221
-1.3892679 -1.1791427 -0.30736607 -0.41543546 -0.6358013 0.31411174
0.878936 1.7265923 -0.09298562 0.12005516 0.26995468 -1.5934305
-0.12619524 0.9111336 0.91827893 0.5948979 0.90334046 -0.84089845
-0.29575598 -0.3024254 -0.03074584 1.4402957 -0.22309914 0.54750854
-0.68767273 -0.0665718 ]]",BEDANKEN,0,"[0.012950755655765533, -0.10741320252418518, 0.5111911296844482, 0.08452003449201584, -0.763892412185669, 0.6458221077919006, -1.389267921447754, -1.179142713546753, -0.3073660731315613, -0.41543546319007874, -0.6358013153076172, 0.31411173939704895, 0.8789359927177429, 1.7265923023223877, -0.0929856151342392, 0.12005516141653061, 0.2699546813964844, -1.593430519104004, -0.12619523704051971, 0.9111335873603821, 0.9182789325714111, 0.5948979258537292, 0.9033404588699341, -0.8408984541893005, -0.2957559823989868, -0.30242541432380676, -0.030745841562747955, 1.440295696258545, -0.22309914231300354, 0.5475085377693176, -0.6876727342605591, -0.06657180190086365]"
"[[ 1.0287632 -0.94657993 0.7852507 -0.37169546 0.3614158 0.5360078
-0.9989285 -0.24402924 -0.47437504 -0.17307773 0.9196662 0.5362411
-0.50366765 0.7600024 0.6701926 -0.5288625 -0.38939306 -1.0722619
-1.33836 1.331829 -0.6958481 0.57767504 0.87002045 -0.59203666
-0.57576096 0.70008135 0.8224436 -0.32370126 0.5276206 -0.62150276
-0.9639951 -0.40879178]]",GOED,1,"[1.0287631750106812, -0.9465799331665039, 0.785250723361969, -0.37169545888900757, 0.3614158034324646, 0.536007821559906, -0.9989284873008728, -0.244029238820076, -0.47437503933906555, -0.17307773232460022, 0.9196661710739136, 0.5362411141395569, -0.5036676526069641, 0.7600023746490479, 0.6701925992965698, -0.528862476348877, -0.38939306139945984, -1.072261929512024, -1.3383599519729614, 1.3318289518356323, -0.6958481073379517, 0.5776750445365906, 0.8700204491615295, -0.5920366644859314, -0.5757609605789185, 0.7000813484191895, 0.8224436044692993, -0.32370126247406006, 0.5276206135749817, -0.6215027570724487, -0.963995099067688, -0.40879178047180176]"
"[[-0.5988377 -1.6454532 0.5696761 0.27801913 -0.39687645 1.5192578
-0.57908726 -1.9851501 1.1918951 -1.3280408 0.27390718 0.71642965
-1.0471189 0.9825219 -0.6625729 -0.28641602 2.0109527 -0.97667754
-2.3978348 2.666861 0.15066166 -1.6922158 0.9681514 -1.1413386
0.08027886 -0.6215762 -0.09759905 0.6157277 -0.01341684 0.9806912
-0.09793527 -1.9003383 ]]",GOEDEMIDDAG,2,"[-0.598837673664093, -1.6454532146453857, 0.5696761012077332, 0.27801913022994995, -0.3968764543533325, 1.5192577838897705, -0.5790872573852539, -1.9851500988006592, 1.1918951272964478, -1.3280408382415771, 0.2739071846008301, 0.7164296507835388, -1.047118902206421, 0.9825218915939331, -0.6625729203224182, -0.28641602396965027, 2.0109527111053467, -0.9766775369644165, -2.3978347778320312, 2.666861057281494, 0.1506616622209549, -1.6922158002853394, 0.9681513905525208, -1.141338586807251, 0.08027885854244232, -0.621576189994812, -0.09759905189275742, 0.6157277226448059, -0.013416841626167297, 0.9806911945343018, -0.0979352742433548, -1.9003382921218872]"
"[[-0.11578767 0.28371835 0.8604608 -0.31061298 -0.78153455 -0.24914108
-2.443886 -0.3267159 0.29237527 -0.7879326 -0.12082034 0.32809633
-1.0938975 0.7125962 0.49117383 -0.3119098 0.15369919 0.39569452
-1.18983 1.2498406 -0.38752007 -0.54896545 -0.620718 0.02654967
-0.10994279 -0.39713857 -0.18035033 -1.554728 0.12178666 -0.5698373
1.1097438 -1.2350351 ]]",GOEDEMORGEN,3,"[-0.11578767001628876, 0.2837183475494385, 0.8604608178138733, -0.3106129765510559, -0.7815345525741577, -0.24914108216762543, -2.4438860416412354, -0.326715886592865, 0.29237526655197144, -0.7879325747489929, -0.12082034349441528, 0.32809633016586304, -1.0938974618911743, 0.7125961780548096, 0.4911738336086273, -0.3119097948074341, 0.15369918942451477, 0.3956945240497589, -1.18982994556427, 1.2498406171798706, -0.38752007484436035, -0.5489654541015625, -0.6207180023193359, 0.02654966711997986, -0.10994279384613037, -0.3971385657787323, -0.18035033345222473, -1.5547280311584473, 0.12178666144609451, -0.5698372721672058, 1.1097438335418701, -1.2350350618362427]"
"[[-0.9557147 0.03368077 1.0923759 -0.41077584 -1.6992741 -0.59566665
-1.1562454 0.8824008 -0.25122 0.625515 1.6720039 0.5639413
-0.35894975 -0.05896414 0.84854823 0.4072139 -1.1763096 -1.5435666
-1.5066593 0.66299886 -1.3121625 1.4410669 -0.6341419 -0.4873513
0.8113033 -0.01502099 0.32262453 2.5855515 -0.7294207 0.7582124
-0.1386989 -0.17011487]]",GOEDENACHT,4,"[-0.9557147026062012, 0.03368076682090759, 1.0923758745193481, -0.4107758402824402, -1.6992740631103516, -0.5956666469573975, -1.1562453508377075, 0.8824008107185364, -0.2512199878692627, 0.6255149841308594, 1.6720038652420044, 0.5639412999153137, -0.35894975066185, -0.058964136987924576, 0.8485482335090637, 0.40721389651298523, -1.176309585571289, -1.5435665845870972, -1.5066592693328857, 0.6629988551139832, -1.3121625185012817, 1.441066861152649, -0.6341419219970703, -0.48735129833221436, 0.8113033175468445, -0.015020990744233131, 0.322624534368515, 2.5855515003204346, -0.7294207215309143, 0.7582123875617981, -0.13869890570640564, -0.17011487483978271]"
"[[ 0.13687058 -1.099074 1.1867181 -1.2216619 -0.965039 -0.6957289
-1.6400954 0.7317837 0.5887569 0.18756844 1.4867041 0.63357025
-1.5853169 0.4157976 0.76919526 0.08512082 -1.0937556 -0.07339763
-2.8580015 1.6835626 -1.4538742 1.2302468 -0.49323368 -0.81243044
-0.11378415 -0.09592562 0.31165385 -0.32617894 0.02981896 -0.01485211
1.1880591 -0.40726334]]",GOEDENAVOND,5,"[0.1368705779314041, -1.0990740060806274, 1.1867181062698364, -1.221661925315857, -0.9650390148162842, -0.6957288980484009, -1.6400953531265259, 0.7317836880683899, 0.5887569189071655, 0.18756844103336334, 1.4867041110992432, 0.6335702538490295, -1.5853168964385986, 0.4157975912094116, 0.7691952586174011, 0.08512081950902939, -1.0937556028366089, -0.07339762896299362, -2.858001470565796, 1.6835626363754272, -1.4538742303848267, 1.2302467823028564, -0.49323368072509766, -0.8124304413795471, -0.11378414928913116, -0.09592562168836594, 0.31165385246276855, -0.3261789381504059, 0.029818959534168243, -0.014852114021778107, 1.1880590915679932, -0.4072633385658264]"
"[[ 1.3673804 1.7083405 0.58775777 0.26056585 -0.29101595 -0.69954723
0.45681733 1.6908163 0.02684825 0.7537957 2.2054908 0.28831166
-1.5712252 -1.6869793 0.8009781 -0.51264 -1.3544708 -0.72502786
-0.31009498 -0.23777717 -1.7524906 1.6333568 1.5263942 0.22317217
-0.40576094 2.2136266 1.4030526 -1.0599502 0.8686069 -1.1141618
-0.01899157 -1.2256656 ]]",JA,6,"[1.3673803806304932, 1.7083405256271362, 0.5877577662467957, 0.260565847158432, -0.29101595282554626, -0.6995472311973572, 0.4568173289299011, 1.6908162832260132, 0.02684824913740158, 0.7537956833839417, 2.205490827560425, 0.2883116602897644, -1.5712251663208008, -1.6869792938232422, 0.8009781241416931, -0.5126399993896484, -1.3544708490371704, -0.725027859210968, -0.3100949823856354, -0.23777717351913452, -1.7524906396865845, 1.6333568096160889, 1.526394248008728, 0.2231721729040146, -0.4057609438896179, 2.2136266231536865, 1.403052568435669, -1.0599502325057983, 0.8686069250106812, -1.1141618490219116, -0.01899157091975212, -1.22566556930542]"
"[[ 1.8208723 -0.59625745 0.15060106 -0.23053254 1.0344827 1.7016335
0.1968202 -0.48188782 0.3004466 -0.53336793 0.89704275 0.26869062
-1.0112952 0.0343281 0.32149622 -0.8189871 0.33571526 -0.57059044
-0.7577773 1.4743904 -0.01735806 -0.4116615 2.3637483 -0.6602343
-0.6101709 1.8139238 1.041902 -0.9006022 0.6082014 -0.23616463
-0.62204087 -0.524899 ]]",LINKS,7,"[1.8208723068237305, -0.5962574481964111, 0.15060105919837952, -0.23053254187107086, 1.034482717514038, 1.7016334533691406, 0.1968201994895935, -0.4818878173828125, 0.30044659972190857, -0.533367931842804, 0.8970427513122559, 0.2686906158924103, -1.011295199394226, 0.034328099340200424, 0.32149621844291687, -0.8189870715141296, 0.33571526408195496, -0.5705904364585876, -0.7577772736549377, 1.4743903875350952, -0.01735806092619896, -0.4116615056991577, 2.36374831199646, -0.660234272480011, -0.6101709008216858, 1.8139238357543945, 1.04190194606781, -0.9006022214889526, 0.6082013845443726, -0.2361646294593811, -0.622040867805481, -0.5248990058898926]"
"[[ 1.3968729 2.340378 0.567814 0.5684975 -0.8795973 -1.0090083
-0.01301649 1.9747832 -0.3257468 0.76960254 1.402801 0.39424965
0.06948688 -0.51904166 1.0961802 -0.34335235 -1.8730735 -1.0801809
0.2751789 -0.96998405 -1.9822828 2.1046805 1.3094746 0.5770473
-0.19693638 1.8973799 0.99536693 0.2668684 0.28499565 -0.9838486
-0.19578832 -0.46982569]]",NEE,8,"[1.396872878074646, 2.3403780460357666, 0.5678139925003052, 0.5684974789619446, -0.8795973062515259, -1.0090082883834839, -0.013016488403081894, 1.974783182144165, -0.3257468044757843, 0.7696025371551514, 1.4028010368347168, 0.39424964785575867, 0.06948687881231308, -0.5190416574478149, 1.0961802005767822, -0.343352347612381, -1.8730734586715698, -1.0801808834075928, 0.2751789093017578, -0.9699840545654297, -1.9822827577590942, 2.1046805381774902, 1.3094745874404907, 0.5770472884178162, -0.19693638384342194, 1.8973798751831055, 0.9953669309616089, 0.2668684124946594, 0.2849956452846527, -0.9838485717773438, -0.19578832387924194, -0.4698256850242615]"
"[[ 1.8143647 -0.9315294 0.28706568 -1.0938002 -0.2451038 0.42331484
-0.31653756 0.00268575 0.89822304 1.0639151 1.7406261 1.1707302
-1.6639041 0.30558044 0.22984704 -1.0260515 -0.08818819 -0.14696571
-1.5269523 0.55015475 -0.12936743 0.93951946 1.3917109 -0.62517196
-1.3869016 1.4833041 1.4647726 -1.0285482 -0.3704591 1.4672718
0.40315557 0.04754414]]",RECHTS,9,"[1.8143646717071533, -0.9315294027328491, 0.28706568479537964, -1.0938001871109009, -0.24510380625724792, 0.4233148396015167, -0.3165375590324402, 0.002685748040676117, 0.8982230424880981, 1.0639151334762573, 1.7406260967254639, 1.1707302331924438, -1.663904070854187, 0.3055804371833801, 0.2298470437526703, -1.0260515213012695, -0.08818819373846054, -0.14696571230888367, -1.5269522666931152, 0.5501547455787659, -0.1293674260377884, 0.939519464969635, 1.391710877418518, -0.625171959400177, -1.386901617050171, 1.4833041429519653, 1.4647725820541382, -1.028548240661621, -0.37045910954475403, 1.4672718048095703, 0.4031555652618408, 0.04754413664340973]"
"[[ 1.681777 -0.79441184 0.07110912 -0.01304399 1.110081 2.1115103
-0.27283993 -1.5035654 0.55927217 -0.9903696 0.05718866 0.19891238
-0.7289791 0.71738267 -0.41428873 -0.58389515 1.1087953 -0.6616588
-0.6330258 1.5830203 0.60013187 -1.0640136 2.3917787 -0.70012283
-0.7359694 0.9847692 0.94458187 -0.5915589 0.3662199 0.39359796
-0.6499281 -0.30565363]]",SALUUT,10,"[1.681777000427246, -0.794411838054657, 0.07110912352800369, -0.013043994084000587, 1.1100809574127197, 2.1115102767944336, -0.2728399336338043, -1.5035654306411743, 0.5592721700668335, -0.9903696179389954, 0.05718865618109703, 0.1989123821258545, -0.7289791107177734, 0.7173826694488525, -0.414288729429245, -0.5838951468467712, 1.1087952852249146, -0.6616588234901428, -0.6330258250236511, 1.5830203294754028, 0.6001318693161011, -1.0640136003494263, 2.3917787075042725, -0.7001228332519531, -0.7359694242477417, 0.9847692251205444, 0.9445818662643433, -0.5915588736534119, 0.3662199079990387, 0.39359796047210693, -0.649928092956543, -0.3056536316871643]"
"[[ 0.56773967 -0.6260798 0.23771092 -0.10016712 0.86817515 0.92371875
-0.16313902 -0.3785063 -0.41955388 -0.586288 -0.33712262 -0.07999519
-0.98128295 0.18787864 -0.36432508 -0.06605241 0.00710656 -0.36359936
-0.00642961 1.257384 0.64647824 -1.1618687 1.3707273 -0.46546304
-0.14522405 0.29306158 0.41898865 -0.12311607 0.24604419 -0.48434448
-0.79076517 0.753163 ]]",SLECHT,11,"[0.5677396655082703, -0.626079797744751, 0.23771092295646667, -0.10016711801290512, 0.8681751489639282, 0.9237187504768372, -0.16313901543617249, -0.37850630283355713, -0.41955387592315674, -0.5862879753112793, -0.3371226191520691, -0.07999519258737564, -0.9812829494476318, 0.18787863850593567, -0.36432507634162903, -0.06605241447687149, 0.007106557488441467, -0.36359935998916626, -0.006429608445614576, 1.257383942604065, 0.6464782357215881, -1.161868691444397, 1.370727300643921, -0.4654630422592163, -0.14522404968738556, 0.29306158423423767, 0.4189886450767517, -0.12311606854200363, 0.24604418873786926, -0.484344482421875, -0.7907651662826538, 0.7531629800796509]"
"[[ 0.9251696 -2.8536968 0.6521715 -1.8256427 0.44136596 0.33299872
-1.4674346 -0.7897836 0.47153538 0.1437631 0.7396151 1.0851275
-1.2100328 1.626544 0.69652647 0.0048107 0.13927785 -0.63199776
-3.0121155 2.1738136 -0.7935879 0.2744053 0.281097 -1.0899199
-0.7062637 -0.04824958 0.3436054 -0.69366074 0.22799876 0.8932708
0.51823634 -0.08065546]]",SMAKELIJK,12,"[0.9251695871353149, -2.853696823120117, 0.6521714925765991, -1.825642704963684, 0.44136595726013184, 0.33299872279167175, -1.4674346446990967, -0.7897835969924927, 0.4715353846549988, 0.14376309514045715, 0.7396150827407837, 1.0851274728775024, -1.2100328207015991, 1.6265439987182617, 0.6965264678001404, 0.004810698330402374, 0.139277845621109, -0.6319977641105652, -3.012115478515625, 2.173813581466675, -0.7935879230499268, 0.274405300617218, 0.2810969948768616, -1.089919924736023, -0.7062637209892273, -0.04824957996606827, 0.3436053991317749, -0.6936607360839844, 0.2279987633228302, 0.8932707905769348, 0.5182363390922546, -0.08065545558929443]"
"[[-1.460053 2.9393604 0.5533726 1.3786266 -1.5367306 -1.2867874
-1.2442408 0.62938243 -0.7092637 -1.0278705 -1.3296479 -0.8105969
0.69997054 0.16618343 0.5113248 0.15563956 -0.4815752 -0.10130378
1.4232539 -1.2767068 -0.3229611 0.45140803 -1.3901687 1.0403453
1.454544 -1.308266 -1.145199 0.661388 -0.20519465 -1.0480955
-0.04290716 -0.36275834]]",SORRY,13,"[-1.4600529670715332, 2.9393603801727295, 0.5533726215362549, 1.3786265850067139, -1.5367306470870972, -1.2867873907089233, -1.2442407608032227, 0.6293824315071106, -0.7092636823654175, -1.027870535850525, -1.3296478986740112, -0.8105968832969666, 0.699970543384552, 0.16618342697620392, 0.5113248229026794, 0.15563955903053284, -0.4815751910209656, -0.10130377858877182, 1.4232538938522339, -1.2767068147659302, -0.32296109199523926, 0.4514080286026001, -1.3901686668395996, 1.040345311164856, 1.454543948173523, -1.308266043663025, -1.145198941230774, 0.6613879799842834, -0.2051946520805359, -1.048095464706421, -0.04290715977549553, -0.3627583384513855]"
"[[ 1.0733138 -0.66348886 0.36737278 0.01765811 0.5730918 1.7617474
-0.45580968 -1.0973133 0.4730748 -0.4665998 0.48594972 0.548745
-0.91712314 0.4846773 0.3774526 -0.5996147 0.28750947 -1.0166394
-0.6239797 1.7935088 -0.03666669 -0.51941574 2.045076 -0.9045104
-0.6312211 0.9698636 1.0522215 -0.14772844 0.7406032 0.2901447
-0.48710442 -0.4619298 ]]",TOT-ZIENS,14,"[1.07331383228302, -0.6634888648986816, 0.3673727810382843, 0.017658105120062828, 0.5730918049812317, 1.7617473602294922, -0.45580968260765076, -1.0973132848739624, 0.4730747938156128, -0.46659979224205017, 0.48594972491264343, 0.5487449765205383, -0.9171231389045715, 0.4846773147583008, 0.3774526119232178, -0.599614679813385, 0.2875094711780548, -1.0166393518447876, -0.6239796876907349, 1.793508768081665, -0.036666687577962875, -0.5194157361984253, 2.0450758934020996, -0.9045103788375854, -0.6312211155891418, 0.9698635935783386, 1.0522215366363525, -0.14772844314575195, 0.7406032085418701, 0.2901447117328644, -0.4871044158935547, -0.4619297981262207]"
1 embeddings label_name labels embeddings2
2 [[ 0.01295076 -0.1074132 0.5111911 0.08452003 -0.7638924 0.6458221 -1.3892679 -1.1791427 -0.30736607 -0.41543546 -0.6358013 0.31411174 0.878936 1.7265923 -0.09298562 0.12005516 0.26995468 -1.5934305 -0.12619524 0.9111336 0.91827893 0.5948979 0.90334046 -0.84089845 -0.29575598 -0.3024254 -0.03074584 1.4402957 -0.22309914 0.54750854 -0.68767273 -0.0665718 ]] BEDANKEN 0 [0.012950755655765533, -0.10741320252418518, 0.5111911296844482, 0.08452003449201584, -0.763892412185669, 0.6458221077919006, -1.389267921447754, -1.179142713546753, -0.3073660731315613, -0.41543546319007874, -0.6358013153076172, 0.31411173939704895, 0.8789359927177429, 1.7265923023223877, -0.0929856151342392, 0.12005516141653061, 0.2699546813964844, -1.593430519104004, -0.12619523704051971, 0.9111335873603821, 0.9182789325714111, 0.5948979258537292, 0.9033404588699341, -0.8408984541893005, -0.2957559823989868, -0.30242541432380676, -0.030745841562747955, 1.440295696258545, -0.22309914231300354, 0.5475085377693176, -0.6876727342605591, -0.06657180190086365]
3 [[ 1.0287632 -0.94657993 0.7852507 -0.37169546 0.3614158 0.5360078 -0.9989285 -0.24402924 -0.47437504 -0.17307773 0.9196662 0.5362411 -0.50366765 0.7600024 0.6701926 -0.5288625 -0.38939306 -1.0722619 -1.33836 1.331829 -0.6958481 0.57767504 0.87002045 -0.59203666 -0.57576096 0.70008135 0.8224436 -0.32370126 0.5276206 -0.62150276 -0.9639951 -0.40879178]] GOED 1 [1.0287631750106812, -0.9465799331665039, 0.785250723361969, -0.37169545888900757, 0.3614158034324646, 0.536007821559906, -0.9989284873008728, -0.244029238820076, -0.47437503933906555, -0.17307773232460022, 0.9196661710739136, 0.5362411141395569, -0.5036676526069641, 0.7600023746490479, 0.6701925992965698, -0.528862476348877, -0.38939306139945984, -1.072261929512024, -1.3383599519729614, 1.3318289518356323, -0.6958481073379517, 0.5776750445365906, 0.8700204491615295, -0.5920366644859314, -0.5757609605789185, 0.7000813484191895, 0.8224436044692993, -0.32370126247406006, 0.5276206135749817, -0.6215027570724487, -0.963995099067688, -0.40879178047180176]
4 [[-0.5988377 -1.6454532 0.5696761 0.27801913 -0.39687645 1.5192578 -0.57908726 -1.9851501 1.1918951 -1.3280408 0.27390718 0.71642965 -1.0471189 0.9825219 -0.6625729 -0.28641602 2.0109527 -0.97667754 -2.3978348 2.666861 0.15066166 -1.6922158 0.9681514 -1.1413386 0.08027886 -0.6215762 -0.09759905 0.6157277 -0.01341684 0.9806912 -0.09793527 -1.9003383 ]] GOEDEMIDDAG 2 [-0.598837673664093, -1.6454532146453857, 0.5696761012077332, 0.27801913022994995, -0.3968764543533325, 1.5192577838897705, -0.5790872573852539, -1.9851500988006592, 1.1918951272964478, -1.3280408382415771, 0.2739071846008301, 0.7164296507835388, -1.047118902206421, 0.9825218915939331, -0.6625729203224182, -0.28641602396965027, 2.0109527111053467, -0.9766775369644165, -2.3978347778320312, 2.666861057281494, 0.1506616622209549, -1.6922158002853394, 0.9681513905525208, -1.141338586807251, 0.08027885854244232, -0.621576189994812, -0.09759905189275742, 0.6157277226448059, -0.013416841626167297, 0.9806911945343018, -0.0979352742433548, -1.9003382921218872]
5 [[-0.11578767 0.28371835 0.8604608 -0.31061298 -0.78153455 -0.24914108 -2.443886 -0.3267159 0.29237527 -0.7879326 -0.12082034 0.32809633 -1.0938975 0.7125962 0.49117383 -0.3119098 0.15369919 0.39569452 -1.18983 1.2498406 -0.38752007 -0.54896545 -0.620718 0.02654967 -0.10994279 -0.39713857 -0.18035033 -1.554728 0.12178666 -0.5698373 1.1097438 -1.2350351 ]] GOEDEMORGEN 3 [-0.11578767001628876, 0.2837183475494385, 0.8604608178138733, -0.3106129765510559, -0.7815345525741577, -0.24914108216762543, -2.4438860416412354, -0.326715886592865, 0.29237526655197144, -0.7879325747489929, -0.12082034349441528, 0.32809633016586304, -1.0938974618911743, 0.7125961780548096, 0.4911738336086273, -0.3119097948074341, 0.15369918942451477, 0.3956945240497589, -1.18982994556427, 1.2498406171798706, -0.38752007484436035, -0.5489654541015625, -0.6207180023193359, 0.02654966711997986, -0.10994279384613037, -0.3971385657787323, -0.18035033345222473, -1.5547280311584473, 0.12178666144609451, -0.5698372721672058, 1.1097438335418701, -1.2350350618362427]
6 [[-0.9557147 0.03368077 1.0923759 -0.41077584 -1.6992741 -0.59566665 -1.1562454 0.8824008 -0.25122 0.625515 1.6720039 0.5639413 -0.35894975 -0.05896414 0.84854823 0.4072139 -1.1763096 -1.5435666 -1.5066593 0.66299886 -1.3121625 1.4410669 -0.6341419 -0.4873513 0.8113033 -0.01502099 0.32262453 2.5855515 -0.7294207 0.7582124 -0.1386989 -0.17011487]] GOEDENACHT 4 [-0.9557147026062012, 0.03368076682090759, 1.0923758745193481, -0.4107758402824402, -1.6992740631103516, -0.5956666469573975, -1.1562453508377075, 0.8824008107185364, -0.2512199878692627, 0.6255149841308594, 1.6720038652420044, 0.5639412999153137, -0.35894975066185, -0.058964136987924576, 0.8485482335090637, 0.40721389651298523, -1.176309585571289, -1.5435665845870972, -1.5066592693328857, 0.6629988551139832, -1.3121625185012817, 1.441066861152649, -0.6341419219970703, -0.48735129833221436, 0.8113033175468445, -0.015020990744233131, 0.322624534368515, 2.5855515003204346, -0.7294207215309143, 0.7582123875617981, -0.13869890570640564, -0.17011487483978271]
7 [[ 0.13687058 -1.099074 1.1867181 -1.2216619 -0.965039 -0.6957289 -1.6400954 0.7317837 0.5887569 0.18756844 1.4867041 0.63357025 -1.5853169 0.4157976 0.76919526 0.08512082 -1.0937556 -0.07339763 -2.8580015 1.6835626 -1.4538742 1.2302468 -0.49323368 -0.81243044 -0.11378415 -0.09592562 0.31165385 -0.32617894 0.02981896 -0.01485211 1.1880591 -0.40726334]] GOEDENAVOND 5 [0.1368705779314041, -1.0990740060806274, 1.1867181062698364, -1.221661925315857, -0.9650390148162842, -0.6957288980484009, -1.6400953531265259, 0.7317836880683899, 0.5887569189071655, 0.18756844103336334, 1.4867041110992432, 0.6335702538490295, -1.5853168964385986, 0.4157975912094116, 0.7691952586174011, 0.08512081950902939, -1.0937556028366089, -0.07339762896299362, -2.858001470565796, 1.6835626363754272, -1.4538742303848267, 1.2302467823028564, -0.49323368072509766, -0.8124304413795471, -0.11378414928913116, -0.09592562168836594, 0.31165385246276855, -0.3261789381504059, 0.029818959534168243, -0.014852114021778107, 1.1880590915679932, -0.4072633385658264]
8 [[ 1.3673804 1.7083405 0.58775777 0.26056585 -0.29101595 -0.69954723 0.45681733 1.6908163 0.02684825 0.7537957 2.2054908 0.28831166 -1.5712252 -1.6869793 0.8009781 -0.51264 -1.3544708 -0.72502786 -0.31009498 -0.23777717 -1.7524906 1.6333568 1.5263942 0.22317217 -0.40576094 2.2136266 1.4030526 -1.0599502 0.8686069 -1.1141618 -0.01899157 -1.2256656 ]] JA 6 [1.3673803806304932, 1.7083405256271362, 0.5877577662467957, 0.260565847158432, -0.29101595282554626, -0.6995472311973572, 0.4568173289299011, 1.6908162832260132, 0.02684824913740158, 0.7537956833839417, 2.205490827560425, 0.2883116602897644, -1.5712251663208008, -1.6869792938232422, 0.8009781241416931, -0.5126399993896484, -1.3544708490371704, -0.725027859210968, -0.3100949823856354, -0.23777717351913452, -1.7524906396865845, 1.6333568096160889, 1.526394248008728, 0.2231721729040146, -0.4057609438896179, 2.2136266231536865, 1.403052568435669, -1.0599502325057983, 0.8686069250106812, -1.1141618490219116, -0.01899157091975212, -1.22566556930542]
9 [[ 1.8208723 -0.59625745 0.15060106 -0.23053254 1.0344827 1.7016335 0.1968202 -0.48188782 0.3004466 -0.53336793 0.89704275 0.26869062 -1.0112952 0.0343281 0.32149622 -0.8189871 0.33571526 -0.57059044 -0.7577773 1.4743904 -0.01735806 -0.4116615 2.3637483 -0.6602343 -0.6101709 1.8139238 1.041902 -0.9006022 0.6082014 -0.23616463 -0.62204087 -0.524899 ]] LINKS 7 [1.8208723068237305, -0.5962574481964111, 0.15060105919837952, -0.23053254187107086, 1.034482717514038, 1.7016334533691406, 0.1968201994895935, -0.4818878173828125, 0.30044659972190857, -0.533367931842804, 0.8970427513122559, 0.2686906158924103, -1.011295199394226, 0.034328099340200424, 0.32149621844291687, -0.8189870715141296, 0.33571526408195496, -0.5705904364585876, -0.7577772736549377, 1.4743903875350952, -0.01735806092619896, -0.4116615056991577, 2.36374831199646, -0.660234272480011, -0.6101709008216858, 1.8139238357543945, 1.04190194606781, -0.9006022214889526, 0.6082013845443726, -0.2361646294593811, -0.622040867805481, -0.5248990058898926]
10 [[ 1.3968729 2.340378 0.567814 0.5684975 -0.8795973 -1.0090083 -0.01301649 1.9747832 -0.3257468 0.76960254 1.402801 0.39424965 0.06948688 -0.51904166 1.0961802 -0.34335235 -1.8730735 -1.0801809 0.2751789 -0.96998405 -1.9822828 2.1046805 1.3094746 0.5770473 -0.19693638 1.8973799 0.99536693 0.2668684 0.28499565 -0.9838486 -0.19578832 -0.46982569]] NEE 8 [1.396872878074646, 2.3403780460357666, 0.5678139925003052, 0.5684974789619446, -0.8795973062515259, -1.0090082883834839, -0.013016488403081894, 1.974783182144165, -0.3257468044757843, 0.7696025371551514, 1.4028010368347168, 0.39424964785575867, 0.06948687881231308, -0.5190416574478149, 1.0961802005767822, -0.343352347612381, -1.8730734586715698, -1.0801808834075928, 0.2751789093017578, -0.9699840545654297, -1.9822827577590942, 2.1046805381774902, 1.3094745874404907, 0.5770472884178162, -0.19693638384342194, 1.8973798751831055, 0.9953669309616089, 0.2668684124946594, 0.2849956452846527, -0.9838485717773438, -0.19578832387924194, -0.4698256850242615]
11 [[ 1.8143647 -0.9315294 0.28706568 -1.0938002 -0.2451038 0.42331484 -0.31653756 0.00268575 0.89822304 1.0639151 1.7406261 1.1707302 -1.6639041 0.30558044 0.22984704 -1.0260515 -0.08818819 -0.14696571 -1.5269523 0.55015475 -0.12936743 0.93951946 1.3917109 -0.62517196 -1.3869016 1.4833041 1.4647726 -1.0285482 -0.3704591 1.4672718 0.40315557 0.04754414]] RECHTS 9 [1.8143646717071533, -0.9315294027328491, 0.28706568479537964, -1.0938001871109009, -0.24510380625724792, 0.4233148396015167, -0.3165375590324402, 0.002685748040676117, 0.8982230424880981, 1.0639151334762573, 1.7406260967254639, 1.1707302331924438, -1.663904070854187, 0.3055804371833801, 0.2298470437526703, -1.0260515213012695, -0.08818819373846054, -0.14696571230888367, -1.5269522666931152, 0.5501547455787659, -0.1293674260377884, 0.939519464969635, 1.391710877418518, -0.625171959400177, -1.386901617050171, 1.4833041429519653, 1.4647725820541382, -1.028548240661621, -0.37045910954475403, 1.4672718048095703, 0.4031555652618408, 0.04754413664340973]
12 [[ 1.681777 -0.79441184 0.07110912 -0.01304399 1.110081 2.1115103 -0.27283993 -1.5035654 0.55927217 -0.9903696 0.05718866 0.19891238 -0.7289791 0.71738267 -0.41428873 -0.58389515 1.1087953 -0.6616588 -0.6330258 1.5830203 0.60013187 -1.0640136 2.3917787 -0.70012283 -0.7359694 0.9847692 0.94458187 -0.5915589 0.3662199 0.39359796 -0.6499281 -0.30565363]] SALUUT 10 [1.681777000427246, -0.794411838054657, 0.07110912352800369, -0.013043994084000587, 1.1100809574127197, 2.1115102767944336, -0.2728399336338043, -1.5035654306411743, 0.5592721700668335, -0.9903696179389954, 0.05718865618109703, 0.1989123821258545, -0.7289791107177734, 0.7173826694488525, -0.414288729429245, -0.5838951468467712, 1.1087952852249146, -0.6616588234901428, -0.6330258250236511, 1.5830203294754028, 0.6001318693161011, -1.0640136003494263, 2.3917787075042725, -0.7001228332519531, -0.7359694242477417, 0.9847692251205444, 0.9445818662643433, -0.5915588736534119, 0.3662199079990387, 0.39359796047210693, -0.649928092956543, -0.3056536316871643]
13 [[ 0.56773967 -0.6260798 0.23771092 -0.10016712 0.86817515 0.92371875 -0.16313902 -0.3785063 -0.41955388 -0.586288 -0.33712262 -0.07999519 -0.98128295 0.18787864 -0.36432508 -0.06605241 0.00710656 -0.36359936 -0.00642961 1.257384 0.64647824 -1.1618687 1.3707273 -0.46546304 -0.14522405 0.29306158 0.41898865 -0.12311607 0.24604419 -0.48434448 -0.79076517 0.753163 ]] SLECHT 11 [0.5677396655082703, -0.626079797744751, 0.23771092295646667, -0.10016711801290512, 0.8681751489639282, 0.9237187504768372, -0.16313901543617249, -0.37850630283355713, -0.41955387592315674, -0.5862879753112793, -0.3371226191520691, -0.07999519258737564, -0.9812829494476318, 0.18787863850593567, -0.36432507634162903, -0.06605241447687149, 0.007106557488441467, -0.36359935998916626, -0.006429608445614576, 1.257383942604065, 0.6464782357215881, -1.161868691444397, 1.370727300643921, -0.4654630422592163, -0.14522404968738556, 0.29306158423423767, 0.4189886450767517, -0.12311606854200363, 0.24604418873786926, -0.484344482421875, -0.7907651662826538, 0.7531629800796509]
14 [[ 0.9251696 -2.8536968 0.6521715 -1.8256427 0.44136596 0.33299872 -1.4674346 -0.7897836 0.47153538 0.1437631 0.7396151 1.0851275 -1.2100328 1.626544 0.69652647 0.0048107 0.13927785 -0.63199776 -3.0121155 2.1738136 -0.7935879 0.2744053 0.281097 -1.0899199 -0.7062637 -0.04824958 0.3436054 -0.69366074 0.22799876 0.8932708 0.51823634 -0.08065546]] SMAKELIJK 12 [0.9251695871353149, -2.853696823120117, 0.6521714925765991, -1.825642704963684, 0.44136595726013184, 0.33299872279167175, -1.4674346446990967, -0.7897835969924927, 0.4715353846549988, 0.14376309514045715, 0.7396150827407837, 1.0851274728775024, -1.2100328207015991, 1.6265439987182617, 0.6965264678001404, 0.004810698330402374, 0.139277845621109, -0.6319977641105652, -3.012115478515625, 2.173813581466675, -0.7935879230499268, 0.274405300617218, 0.2810969948768616, -1.089919924736023, -0.7062637209892273, -0.04824957996606827, 0.3436053991317749, -0.6936607360839844, 0.2279987633228302, 0.8932707905769348, 0.5182363390922546, -0.08065545558929443]
15 [[-1.460053 2.9393604 0.5533726 1.3786266 -1.5367306 -1.2867874 -1.2442408 0.62938243 -0.7092637 -1.0278705 -1.3296479 -0.8105969 0.69997054 0.16618343 0.5113248 0.15563956 -0.4815752 -0.10130378 1.4232539 -1.2767068 -0.3229611 0.45140803 -1.3901687 1.0403453 1.454544 -1.308266 -1.145199 0.661388 -0.20519465 -1.0480955 -0.04290716 -0.36275834]] SORRY 13 [-1.4600529670715332, 2.9393603801727295, 0.5533726215362549, 1.3786265850067139, -1.5367306470870972, -1.2867873907089233, -1.2442407608032227, 0.6293824315071106, -0.7092636823654175, -1.027870535850525, -1.3296478986740112, -0.8105968832969666, 0.699970543384552, 0.16618342697620392, 0.5113248229026794, 0.15563955903053284, -0.4815751910209656, -0.10130377858877182, 1.4232538938522339, -1.2767068147659302, -0.32296109199523926, 0.4514080286026001, -1.3901686668395996, 1.040345311164856, 1.454543948173523, -1.308266043663025, -1.145198941230774, 0.6613879799842834, -0.2051946520805359, -1.048095464706421, -0.04290715977549553, -0.3627583384513855]
16 [[ 1.0733138 -0.66348886 0.36737278 0.01765811 0.5730918 1.7617474 -0.45580968 -1.0973133 0.4730748 -0.4665998 0.48594972 0.548745 -0.91712314 0.4846773 0.3774526 -0.5996147 0.28750947 -1.0166394 -0.6239797 1.7935088 -0.03666669 -0.51941574 2.045076 -0.9045104 -0.6312211 0.9698636 1.0522215 -0.14772844 0.7406032 0.2901447 -0.48710442 -0.4619298 ]] TOT-ZIENS 14 [1.07331383228302, -0.6634888648986816, 0.3673727810382843, 0.017658105120062828, 0.5730918049812317, 1.7617473602294922, -0.45580968260765076, -1.0973132848739624, 0.4730747938156128, -0.46659979224205017, 0.48594972491264343, 0.5487449765205383, -0.9171231389045715, 0.4846773147583008, 0.3774526119232178, -0.599614679813385, 0.2875094711780548, -1.0166393518447876, -0.6239796876907349, 1.793508768081665, -0.036666687577962875, -0.5194157361984253, 2.0450758934020996, -0.9045103788375854, -0.6312211155891418, 0.9698635935783386, 1.0522215366363525, -0.14772844314575195, 0.7406032085418701, 0.2901447117328644, -0.4871044158935547, -0.4619297981262207]

97
export_embeddings.py Normal file
View File

@ -0,0 +1,97 @@
import multiprocessing
import os
import torch
import argparse
from datasets.dataset_loader import LocalDatasetLoader
from datasets.embedding_dataset import SLREmbeddingDataset
from torch.utils.data import DataLoader
from datasets import SLREmbeddingDataset, collate_fn_padd
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
import numpy as np
import random
import pandas as pd
# Global RNG seeding so embedding export is reproducible across runs.
seed = 43
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Force deterministic kernels (may be slower; some ops raise if no
# deterministic implementation exists).
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
# Dedicated generator handed to the DataLoader for reproducible worker seeding.
generator = torch.Generator()
generator.manual_seed(seed)
def seed_worker(worker_id):
    """DataLoader worker init hook: derive numpy/random seeds from torch's seed."""
    derived_seed = torch.initial_seed() % 2**32
    np.random.seed(derived_seed)
    random.seed(derived_seed)
# NOTE(review): this re-creates and re-seeds the module-level generator that
# was already built above — presumably a leftover duplicate; confirm and remove.
generator = torch.Generator()
generator.manual_seed(seed)
# NOTE(review): 'os' is already imported above; this repeat is harmless but
# redundant. The CUBLAS setting is required by torch.use_deterministic_algorithms
# on CUDA builds.
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def parse_args():
    """Parse CLI options for embedding export.

    Returns an argparse.Namespace with checkpoint/output/dataset paths
    (all default None) and the output format ('csv' by default, or 'json').
    """
    parser = argparse.ArgumentParser(description='Export embeddings')
    for flag, description in (
        ('--checkpoint', 'Path to checkpoint'),
        ('--output', 'Path to output'),
        ('--dataset', 'Path to data'),
    ):
        parser.add_argument(flag, type=str, default=None, help=description)
    parser.add_argument('--format', type=str, default='csv',
                        help='Format of the output file (csv, json)')
    return parser.parse_args()
# ---------------------------------------------------------------------------
# Script body: load a trained SPOTER embedding model, run it over every sample
# of the dataset, and write embeddings + labels out as csv/json.
# ---------------------------------------------------------------------------
args = parse_args()

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# Rebuild the model from the architecture hyper-parameters stored in the
# checkpoint, then load its trained weights.
checkpoint = torch.load(args.checkpoint, map_location=device)
model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
# Bug fix: switch to inference mode so dropout/normalization layers behave
# deterministically during export.
model.eval()

dataset_loader = LocalDatasetLoader()
dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_padd,
    pin_memory=torch.cuda.is_available(),
    # num_workers=0,  # Uncomment (and drop the next line) to disable multithreading
    num_workers=multiprocessing.cpu_count(),
    worker_init_fn=seed_worker,
    generator=generator,
)

# One embedding vector per dataset row, in dataset order.
embeddings = []
with torch.no_grad():
    for inputs, labels, masks in data_loader:
        inputs = inputs.to(device)
        masks = masks.to(device)
        outputs = model(inputs, masks)
        for n in range(outputs.shape[0]):
            embeddings.append(outputs[n].cpu().numpy())

# Re-attach the embeddings to the source rows and keep only the columns the
# prediction code consumes.
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
# 'embeddings2' is a plain list of floats that survives the csv round-trip.
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist()[0])
if args.format == 'json':
    df.to_json(args.output, orient='records')
elif args.format == 'csv':
    df.to_csv(args.output, index=False)

71
export_model.py Normal file
View File

@ -0,0 +1,71 @@
import numpy as np
import onnx
import torch
import torchvision
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
# Export settings.
model_name = 'embedding_model'

# Export runs on CPU; no GPU is needed to trace/serialize the model.
device = torch.device("cpu")

CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Rebuild the model from the hyper-parameters stored in the checkpoint.
model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])

# set model to evaluation mode (required for a faithful export of
# dropout/normalization layers)
model.eval()

model_export = "onnx"

# Dummy input shaped (batch, frames, keypoints, xy) used to trace the forward
# pass.  NOTE(review): assumes 10 frames of 54 2-D keypoints — confirm against
# the training pipeline.
dummy_input = torch.randn(1, 10, 54, 2).to(device)

if model_export == "coreml":
    traced_model = torch.jit.trace(model, dummy_input)
    traced_model(dummy_input)  # sanity-check that the traced module runs
    import coremltools as ct
    # Convert to Core ML
    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
    )
    # Save Core ML model
    coreml_model.save("out-models/" + model_name + ".mlmodel")
else:
    # Bug fix: the script previously exported twice (to models/ and
    # out-models/ with different settings) but verified only the first file.
    # Export once, with the full parameter set, and verify that same file.
    output_file = 'out-models/' + model_name + '.onnx'
    torch.onnx.export(
        model,                     # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        output_file,               # where to save the model
        export_params=True,        # store the trained weights inside the model file
        opset_version=9,           # the ONNX version to export the model to
        do_constant_folding=True,  # fold constants for optimization
        input_names=['X'],         # the model's input names
        output_names=['Y'],        # the model's output names
    )
    # Reload the exported ONNX model for verification.
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

73
hyperparam_opt.py Normal file
View File

@ -0,0 +1,73 @@
from clearml.automation import UniformParameterRange, UniformIntegerParameterRange, DiscreteParameterRange
from clearml.automation import HyperParameterOptimizer
from clearml.automation.optuna import OptimizerOptuna
from optuna.pruners import HyperbandPruner, MedianPruner
from clearml import Task
# Register this run as an optimizer task in ClearML.
task = Task.init(
    project_name='SpoterEmbedding',
    task_name='Automatic Hyper-Parameter Optimization',
    task_type=Task.TaskTypes.optimizer,
    reuse_last_task_id=False
)

optimizer = HyperParameterOptimizer(
    # specifying the task to be optimized, task must be in system already so it can be cloned
    base_task_id="4504e0b3ec6745249d3d4c94d3d40652",
    # setting the hyperparameters to optimize
    hyper_parameters=[
        # epochs:
        DiscreteParameterRange('Args/epochs', [200]),
        # learning rate
        UniformParameterRange('Args/lr', 0.000001, 0.01),
        # optimizer
        DiscreteParameterRange('Args/optimizer', ['ADAM', 'SGD']),
        # vector length
        UniformIntegerParameterRange('Args/vector_length', 10, 100),
    ],
    # setting the objective metric we want to maximize/minimize
    objective_metric_title='train_loss',
    objective_metric_series='loss',
    objective_metric_sign='min',
    # setting optimizer
    optimizer_class=OptimizerOptuna,
    # configuring optimization parameters
    execution_queue='default',
    # NOTE(review): time limits are presumably minutes (ClearML convention) —
    # confirm against the HyperParameterOptimizer docs.
    optimization_time_limit=360,
    compute_time_limit=480,
    total_max_jobs=20,
    min_iteration_per_job=0,
    max_iteration_per_job=150000,
    pool_period_min=0.1,
    save_top_k_tasks_only=3,
    # Median pruning stops clearly under-performing trials early.
    optuna_pruner=MedianPruner(),
)
def job_complete_callback(
    job_id,                  # type: str
    objective_value,         # type: float
    objective_iteration,     # type: int
    job_parameters,          # type: dict
    top_performance_job_id   # type: str
):
    """Callback invoked by the optimizer after each HPO job finishes."""
    details = (job_id, objective_value, objective_iteration, job_parameters)
    print('Job completed!', *details)
    # Celebrate when this job set a new best objective.
    is_new_record = job_id == top_performance_job_id
    if is_new_record:
        print('WOOT WOOT we broke the record! Objective reached {}'.format(objective_value))
# Run the optimizer task itself on the 'hypertuning' queue instead of locally.
task.execute_remotely(queue_name='hypertuning', exit_process=True)
# Report progress every 0.3 minutes.
optimizer.set_report_period(0.3)
optimizer.start(job_complete_callback=job_complete_callback)
# Block until the optimization limits are reached.
optimizer.wait()
# Print the ids of the three best experiments, then shut everything down.
top_exp = optimizer.get_top_experiments(top_k=3)
print([t.id for t in top_exp])
optimizer.stop()

View File

@ -61,20 +61,25 @@ def map_blazepose_keypoint(column):
return f"{mapped}_{hand}{suffix}"
def map_blazepose_df(df):
def map_blazepose_df(df, rename=True):
to_drop = []
if rename:
renamings = {}
for column in df.columns:
mapped_column = map_blazepose_keypoint(column)
if mapped_column:
renamings[column] = mapped_column
else:
to_drop.append(column)
df = df.rename(columns=renamings)
for index, row in df.iterrows():
sequence_size = len(row["leftEar_Y"])
lsx = row["leftShoulder_X"]
rsx = row["rightShoulder_X"]
lsy = row["leftShoulder_Y"]
rsy = row["rightShoulder_Y"]
# convert all to list
lsx = lsx[1:-1].split(",")
rsx = rsx[1:-1].split(",")
lsy = lsy[1:-1].split(",")
rsy = rsy[1:-1].split(",")
sequence_size = len(lsx)
neck_x = []
neck_y = []
# Treat each element of the sequence (analyzed frame) individually
@ -84,4 +89,5 @@ def map_blazepose_df(df):
df.loc[index, "neck_X"] = str(neck_x)
df.loc[index, "neck_Y"] = str(neck_y)
df.drop(columns=to_drop, inplace=True)
return df

View File

@ -5,30 +5,30 @@ import pandas as pd
from normalization.hand_normalization import normalize_hands_full
from normalization.body_normalization import normalize_body_full
DATASET_PATH = './data/wlasl'
DATASET_PATH = './data/processed'
# Load the dataset
df = pd.read_csv(os.path.join(DATASET_PATH, "WLASL100_train.csv"), encoding="utf-8")
df = pd.read_csv(os.path.join(DATASET_PATH, "spoter_train.csv"), encoding="utf-8")
print(df.head())
print(df.columns)
# Retrieve metadata
video_size_heights = df["video_height"].to_list()
video_size_widths = df["video_width"].to_list()
# video_size_heights = df["video_height"].to_list()
# video_size_widths = df["video_width"].to_list()
# Delete redundant (non-related) properties
del df["video_height"]
del df["video_width"]
# del df["video_height"]
# del df["video_width"]
# Temporarily remove other relevant metadata
labels = df["labels"].to_list()
video_fps = df["fps"].to_list()
signs = df["sign"].to_list()
del df["labels"]
del df["fps"]
del df["split"]
del df["video_id"]
del df["label_name"]
del df["length"]
del df["sign"]
del df["path"]
del df["participant_id"]
del df["sequence_id"]
# Convert the strings into lists
@ -41,7 +41,7 @@ for column in df.columns:
# Perform the normalizations
df = normalize_hands_full(df)
df, invalid_row_indexes = normalize_body_full(df)
# df, invalid_row_indexes = normalize_body_full(df)
# Clear lists of items from deleted rows
# labels = [t for i, t in enumerate(labels) if i not in invalid_row_indexes]
@ -49,6 +49,6 @@ df, invalid_row_indexes = normalize_body_full(df)
# Return the metadata back to the dataset
df["labels"] = labels
df["fps"] = video_fps
df["sign"] = signs
df.to_csv(os.path.join(DATASET_PATH, "wlasl_train_norm.csv"), encoding="utf-8", index=False)
df.to_csv(os.path.join(DATASET_PATH, "spoter_train_norm.csv"), encoding="utf-8", index=False)

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "c20f7fd5",
"metadata": {},
"outputs": [],
@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "ada032d0",
"metadata": {},
"outputs": [],
@ -22,13 +22,12 @@
"import os\n",
"import os.path as op\n",
"import pandas as pd\n",
"import json\n",
"import base64"
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "05682e73",
"metadata": {},
"outputs": [],
@ -38,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "fede7684",
"metadata": {},
"outputs": [],
@ -48,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "ce531994",
"metadata": {},
"outputs": [],
@ -64,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "f4a2d672",
"metadata": {},
"outputs": [],
@ -87,17 +86,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "1d9db764",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f29f89e3ed0>"
"<torch._C.Generator at 0x7f010919d710>"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -119,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "71224139",
"metadata": {},
"outputs": [],
@ -155,7 +154,7 @@
"# checkpoint = torch.load(model.get_weights())\n",
"\n",
"## Set your path to checkoint here\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_6.pth\"\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"\n",
"model = SPOTER_EMBEDDINGS(\n",
@ -169,27 +168,28 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 24,
"id": "ba6b58f0",
"metadata": {},
"outputs": [],
"source": [
"SL_DATASET = 'wlasl' # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
"\n",
"if SL_DATASET == 'fingerspelling':\n",
" dataset_name = \"fingerspelling\"\n",
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
"elif SL_DATASET == 'wlasl':\n",
" dataset_name = \"wlasl\"\n",
" num_classes = 100\n",
" split_dataset_path = \"WLASL100_train.csv\"\n",
"else:\n",
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
" num_classes = 64\n",
" split_dataset_path = \"LSA64_{}.csv\"\n",
" \n",
" split_dataset_path = \"WLASL100_{}.csv\"\n",
"elif SL_DATASET == 'basic-signs':\n",
" dataset_name = \"basic-signs\"\n",
" split_dataset_path = \"basic-signs_{}.csv\"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 25,
"id": "5643a72c",
"metadata": {},
"outputs": [],
@ -228,7 +228,7 @@
"outputs": [],
"source": [
"dataloaders = {}\n",
"splits = ['train', 'val']\n",
"splits = ['train', 'val']\n",
"dfs = {}\n",
"for split in splits:\n",
" split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
@ -269,6 +269,8 @@
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
" k += 1\n",
" inputs = inputs.to(device)\n",
" \n",
"\n",
" masks = masks.to(device)\n",
" outputs = model(inputs, masks)\n",
" for n in range(outputs.shape[0]):\n",
@ -285,7 +287,7 @@
{
"data": {
"text/plain": [
"(810, 810)"
"(164, 164)"
]
},
"execution_count": 19,
@ -299,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
@ -311,6 +313,70 @@
" df['embeddings'] = embeddings_split[split]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0b9fb9c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
" ... \n",
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
"Name: embeddings, Length: 164, dtype: object\n",
"0 TOT-ZIENS\n",
"1 GOED\n",
"2 GOEDENACHT\n",
"3 NEE\n",
"4 SLECHT\n",
" ... \n",
"159 SORRY\n",
"160 GOEDEMORGEN\n",
"161 LINKS\n",
"162 TOT-ZIENS\n",
"163 GOED\n",
"Name: label_name, Length: 164, dtype: object\n",
"0 0\n",
"1 1\n",
"2 2\n",
"3 3\n",
"4 4\n",
" ..\n",
"159 7\n",
"160 5\n",
"161 13\n",
"162 0\n",
"163 1\n",
"Name: labels, Length: 164, dtype: int64\n"
]
}
],
"source": [
"print(dfs['train'][\"embeddings\"])\n",
"print(dfs['train'][\"label_name\"])\n",
"print(dfs['train'][\"labels\"])\n",
"\n",
"# only keep these columns\n",
"dfs['train'] = dfs['train'][['embeddings', 'label_name', 'labels']]\n",
"\n",
"# convert embeddings to string\n",
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
"\n",
"# save the dfs['train']\n",
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
]
},
{
"cell_type": "markdown",
"id": "2951638d",
@ -322,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"id": "7399b8ae",
"metadata": {},
"outputs": [
@ -331,16 +397,16 @@
"output_type": "stream",
"text": [
"Using centroids only\n",
"Top-1 accuracy: 5.19 %\n",
"Top-5 embeddings class match: 17.65 % (Picks any class in the 5 closest embeddings)\n",
"Top-1 accuracy: 80.00 %\n",
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
"\n",
"################################\n",
"\n",
"Using all embeddings\n",
"Top-1 accuracy: 5.31 %\n",
"5-nn accuracy: 5.56 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 15.43 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 15.56 % (Picks the 5 closest distinct classes)\n",
"Top-1 accuracy: 80.00 %\n",
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
"\n",
"################################\n",
"\n"
@ -375,13 +441,13 @@
" sorted_labels = labels[argsort]\n",
" if sorted_labels[0] == true_label:\n",
" top1 += 1\n",
" if use_centroids:\n",
" good_samples.append(df_val.loc[i, 'video_id'])\n",
" else:\n",
" good_samples.append((df_val.loc[i, 'video_id'],\n",
" df_train.loc[argsort[0], 'video_id'],\n",
" i,\n",
" argsort[0]))\n",
" # if use_centroids:\n",
" # good_samples.append(df_val.loc[i, 'video_id'])\n",
" # else:\n",
" # good_samples.append((df_val.loc[i, 'video_id'],\n",
" # df_train.loc[argsort[0], 'video_id'],\n",
" # i,\n",
" # argsort[0]))\n",
"\n",
"\n",
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",

File diff suppressed because one or more lines are too long

93
predictions/k_nearest.py Normal file
View File

@ -0,0 +1,93 @@
import numpy as np
from collections import Counter
# TODO scaling van distance tov intra distances?
# TODO efficientere manier om k=1 te doen
def minkowski_distance_p(x, y, p=2):
    """Return the p-th power of the L**p distance along the last axis.

    p == inf yields the Chebyshev distance and p == 1 the Manhattan
    distance; otherwise the result is sum(|y - x| ** p) WITHOUT the final
    1/p root (see minkowski_distance for the rooted version).
    """
    a = np.asarray(x)
    b = np.asarray(y)
    # Find the smallest common datatype with float64 (return type of this
    # function) - addresses #10262.
    # Don't just cast to float64 for complex input case.
    dtype = np.promote_types(np.promote_types(a.dtype, b.dtype), 'float64')
    diff = np.abs(b.astype(dtype) - a.astype(dtype))
    if p == np.inf:
        return np.amax(diff, axis=-1)
    if p == 1:
        return np.sum(diff, axis=-1)
    return np.sum(diff ** p, axis=-1)
def minkowski_distance(x, y, p=2):
    """Return the L**p distance between x and y along the last axis."""
    # For p in {1, inf} the power form already IS the distance; otherwise
    # take the 1/p root of the summed powers.
    if p == np.inf or p == 1:
        return minkowski_distance_p(x, y, p)
    return minkowski_distance_p(x, y, p) ** (1. / p)
class KNearestNeighbours:
    """k-NN classifier over precomputed sign-embedding vectors.

    References are supplied as a pandas DataFrame whose 'embeddings2'
    column holds stringified float lists and whose 'label_name' column
    holds the class names.
    """

    def __init__(self, k=5):
        # Number of neighbours consulted per prediction.
        self.k = k
        self.embeddings = None        # full reference DataFrame
        self.embeddings_list = None   # parsed list-of-floats vectors

    def set_embeddings(self, embeddings):
        """Store the reference DataFrame and parse its embedding vectors."""
        self.embeddings = embeddings
        parsed = embeddings.drop(columns=['labels', 'label_name', 'embeddings'])
        # 'embeddings2' is a stringified list ("[0.1, 0.2, ...]"); strip the
        # brackets and convert each element back to float.
        parsed["embeddings"] = parsed["embeddings2"].apply(
            lambda text: [float(part) for part in text[1:-1].split(", ")]
        )
        parsed = parsed.drop(columns=['embeddings2'])
        self.embeddings_list = parsed["embeddings"].tolist()

    def distance_matrix(self, keypoints, p=2, threshold=1000000):
        """Pairwise Minkowski distances between query rows and references.

        Uses one broadcast computation when the intermediate fits under
        `threshold` elements, otherwise falls back to a row/column loop.
        """
        queries = np.array(keypoints)
        refs = np.asarray(self.embeddings_list)
        m, k = queries.shape
        n, kk = refs.shape
        if k != kk:
            raise ValueError(f"x contains {k}-dimensional vectors but y contains "
                             f"{kk}-dimensional vectors")
        if m * n * k <= threshold:
            # print("Using minkowski_distance")
            return minkowski_distance(queries[:, np.newaxis, :], refs[np.newaxis, :, :], p)
        result = np.empty((m, n), dtype=float)  # FIXME: figure out the best dtype
        # Loop over the shorter axis to minimise Python-level iterations.
        if m < n:
            for row in range(m):
                result[row, :] = minkowski_distance(queries[row], refs, p)
        else:
            for col in range(n):
                result[:, col] = minkowski_distance(queries, refs[col], p)
        return result

    def predict(self, key_points_embeddings):
        """Return (label, score) for a single query embedding.

        label is the majority class of the k nearest references; score is
        the mean distance to the references of that class.
        """
        dist_matrix = self.distance_matrix(key_points_embeddings, p=2, threshold=1000000)
        # Indices of the k closest references for the (single) query row.
        nearest = np.argsort(dist_matrix)[0][:self.k]
        names = self.embeddings["label_name"].iloc[nearest].tolist()
        winner = Counter(names).most_common()[0][0]
        # Keep only neighbours belonging to the winning class; their average
        # distance serves as the confidence score.
        winning = [idx for idx in nearest if self.embeddings["label_name"].iloc[idx] == winner]
        score = np.mean(dist_matrix[0][winning])
        return winner, score

86
predictions/plotting.py Normal file
View File

@ -0,0 +1,86 @@
import json
from matplotlib import pyplot as plt
def load_results():
    """Read the stored knn validation results from disk."""
    with open("predictions/test_results/knn.json", 'r') as handle:
        return json.load(handle)
def plot_all():
    """Print the mean detection latency and plot overall + per-label accuracy."""
    results = load_results()
    print(f"average elapsed time to detect a sign: {get_general_elapsed_time(results)}")
    plot_general_accuracy(results)
    for label in results:
        plot_accuracy_per_label(results, label)
def general_accuracy(results):
    """Average the per-label accuracy curves into one overall curve.

    Positions supported by fewer than five labels are cut off.
    """
    label_accuracy = get_label_accuracy(results)
    totals = []
    counts = []
    for curve in label_accuracy.values():
        for index, value in enumerate(curve):
            # Grow the accumulators lazily to the longest curve seen so far.
            if index >= len(totals):
                totals.append(0)
                counts.append(0)
            totals[index] += value
            counts[index] += 1
    averaged = []
    for total, count in zip(totals, counts):
        if count < 5:
            break
        averaged.append(total / count)
    return averaged
def plot_general_accuracy(results):
    """Plot the overall accuracy as a function of buffer position."""
    accuracy = general_accuracy(results)
    plt.plot(accuracy)
    # Bug fix: plt.title was being ASSIGNED a string (overwriting the pyplot
    # function for the rest of the process) instead of being called.
    plt.title("General accuracy")
    plt.ylabel('accuracy')
    plt.xlabel('buffer')
    plt.show()
def plot_accuracy_per_label(results, label):
    """Plot the accuracy curve of a single sign label."""
    accuracy = get_label_accuracy(results)
    plt.plot(accuracy[label], label=label)
    # Bug fix: 'plt.titel' was a misspelled attribute assignment, so the
    # title was never set; call plt.title() instead.
    plt.title(f"Accuracy per label {label}")
    plt.ylabel('accuracy')
    plt.xlabel('prediction')
    plt.legend()
    plt.show()
def get_label_accuracy(results):
    """Per-label accuracy curves.

    For each label, returns the fraction of correct predictions at every
    buffer position, truncated at the first position supported by fewer
    than two samples.
    """
    correct_counts = {}
    totals = {}
    for label, predictions in results.items():
        correct_counts.setdefault(label, [])
        totals.setdefault(label, [])
        for prediction in predictions:
            for index, value in enumerate(prediction["predictions"]):
                # Grow the per-label accumulators lazily.
                if index >= len(correct_counts[label]):
                    correct_counts[label].append(0)
                    totals[label].append(0)
                if value["correct"]:
                    correct_counts[label][index] += 1
                totals[label][index] += 1
    curves = {}
    for label in correct_counts:
        curves[label] = []
        for index, hits in enumerate(correct_counts[label]):
            if totals[label][index] < 2:
                break
            curves[label].append(hits / totals[label][index])
    return curves
def get_general_elapsed_time(results):
    """Mean of the per-label average detection times."""
    per_label = get_label_elapsed_time(results)
    return sum(per_label[label] for label in results) / len(results)
def get_label_elapsed_time(results):
    """Average elapsed prediction time for each label."""
    averages = {}
    for label, entries in results.items():
        averages[label] = sum(entry["elapsed_time"] for entry in entries) / len(entries)
    return averages
if __name__ == '__main__':
plot_all()

267
predictions/predictor.py Normal file
View File

@ -0,0 +1,267 @@
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
import torch
from predictions.k_nearest import KNearestNeighbours
# Inference device: GPU when available, CPU otherwise.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

from models import SPOTER_EMBEDDINGS

# MediaPipe pose landmark indices kept for the body, in the order the model
# expects.  Index 33 is the synthetic neck landmark appended by
# Predictor.extract_keypoints (MediaPipe pose itself defines 0-32).
BODY_IDENTIFIERS = [
    0,
    33,
    5,
    2,
    8,
    7,
    12,
    11,
    14,
    13,
    16,
    15,
]
# MediaPipe hand landmark indices in model order.
# NOTE(review): order appears to be wrist, then each finger tip-to-base —
# confirm it matches the training-time column order.
HAND_IDENTIFIERS = [
    0,
    8,
    7,
    6,
    5,
    12,
    11,
    10,
    9,
    16,
    15,
    14,
    13,
    20,
    19,
    18,
    17,
    4,
    3,
    2,
    1,
]
# Trained embedding-model weights loaded by Predictor.
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
class Predictor:
    """Video sign predictor.

    Extracts pose + hand keypoints with MediaPipe Holistic, normalizes them,
    embeds keypoint sequences with the SPOTER embedding model, and
    classifies the embedding with a pluggable predictor (1-NN by default).
    """

    def __init__(self, embeddings_path, predictor_type):
        """Load MediaPipe, the reference embedding csv, the model weights,
        and wire up the classifier.

        embeddings_path: csv of reference embeddings (see export_embeddings.py).
        predictor_type: object exposing set_embeddings()/predict(); None
            selects a 1-nearest-neighbour classifier.
        """
        # Initialize MediaPipe Hands model
        self.holistic = mp.solutions.holistic.Holistic(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=2
        )
        self.mp_holistic = mp.solutions.holistic
        self.mp_drawing = mp.solutions.drawing_utils
        # buffer = []
        # MediaPipe pose landmark indices used by the normalization below.
        self.left_shoulder_index = 11
        self.right_shoulder_index = 12
        # 33 is the synthetic neck point appended in extract_keypoints.
        self.neck_index = 33
        self.nose_index = 0
        self.left_eye_index = 2
        # load training embedding csv
        self.embeddings = pd.read_csv(embeddings_path)
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        self.model = SPOTER_EMBEDDINGS(
            features=checkpoint["config_args"].vector_length,
            hidden_dim=checkpoint["config_args"].hidden_dim,
            norm_emb=checkpoint["config_args"].normalize_embeddings,
        ).to(device)
        self.model.load_state_dict(checkpoint["state_dict"])
        # Fall back to 1-nearest-neighbour when no predictor is injected.
        if predictor_type is None:
            self.predictor = KNearestNeighbours(1)
        else:
            self.predictor = predictor_type
        self.predictor.set_embeddings(self.embeddings)

    def extract_keypoints(self, image_orig):
        """Run MediaPipe on one BGR frame.

        Returns the normalized (body + left hand + right hand, 2) keypoint
        array shifted to the [-0.5, 0.5] interval, or None when no pose or
        no hand was detected.
        """
        image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
        results = self.holistic.process(image)

        # NOTE(review): this nested helper shadows the method name.
        def extract_keypoints(lmks):
            # Landmark list -> (N, 2) array of x/y, or None when absent.
            if lmks:
                a = np.array([[float(lmk.x), float(lmk.y)] for lmk in lmks.landmark])
                return a
            return None

        def calculate_neck(keypoints):
            # Append the shoulder midpoint as an extra "neck" landmark.
            if keypoints is not None:
                left_shoulder = keypoints[11]
                right_shoulder = keypoints[12]
                neck = [(float(left_shoulder[0]) + float(right_shoulder[0])) / 2,
                        (float(left_shoulder[1]) + float(right_shoulder[1])) / 2]
                # add neck to keypoints
                keypoints = np.append(keypoints, [neck], axis=0)
                return keypoints
            return None

        pose = extract_keypoints(results.pose_landmarks)
        pose = calculate_neck(pose)
        if pose is None:
            return None
        pose_norm = self.normalize_pose(pose)
        # filter out keypoints that are not in BODY_IDENTIFIERS and make sure they are in the correct order
        pose_norm = pose_norm[BODY_IDENTIFIERS]
        left_hand = extract_keypoints(results.left_hand_landmarks)
        right_hand = extract_keypoints(results.right_hand_landmarks)
        if left_hand is None and right_hand is None:
            return None
        # normalize hands
        if left_hand is not None:
            left_hand = self.normalize_hand(left_hand)
        else:
            left_hand = np.zeros((21, 2))
        if right_hand is not None:
            right_hand = self.normalize_hand(right_hand)
        else:
            right_hand = np.zeros((21, 2))
        left_hand = left_hand[HAND_IDENTIFIERS]
        right_hand = right_hand[HAND_IDENTIFIERS]
        # combine pose and hands
        pose_norm = np.append(pose_norm, left_hand, axis=0)
        pose_norm = np.append(pose_norm, right_hand, axis=0)
        # move interval
        pose_norm -= 0.5
        return pose_norm

    # if we have the keypoints, normalize single body, keypoints is numpy array of (identifiers, 2)
    def normalize_pose(self, keypoints):
        """Normalize body keypoints into a head-metric bounding box.

        Mutates and returns `keypoints`; returns them unchanged when the
        reference landmarks are missing/degenerate.
        """
        left_shoulder = keypoints[self.left_shoulder_index]
        right_shoulder = keypoints[self.right_shoulder_index]
        neck = keypoints[self.neck_index]
        nose = keypoints[self.nose_index]
        # Prevent from even starting the analysis if some necessary elements are not present
        if (left_shoulder[0] == 0 or right_shoulder[0] == 0
                or (left_shoulder[0] == right_shoulder[0] and left_shoulder[1] == right_shoulder[1])) and (
                neck[0] == 0 or nose[0] == 0 or (neck[0] == nose[0] and neck[1] == nose[1])):
            return keypoints
        # Prefer the shoulder distance as the "head metric"; fall back to the
        # neck-nose distance when the shoulders are unusable.
        if left_shoulder[0] != 0 and right_shoulder[0] != 0 and (
                left_shoulder[0] != right_shoulder[0] or left_shoulder[1] != right_shoulder[1]):
            shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + (
                (left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5)
            head_metric = shoulder_distance
        else:
            neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5)
            head_metric = neck_nose_distance
        # Set the starting and ending point of the normalization bounding box
        starting_point = [keypoints[self.neck_index][0] - 3 * head_metric,
                          keypoints[self.left_eye_index][1] + head_metric]
        ending_point = [keypoints[self.neck_index][0] + 3 * head_metric, starting_point[1] - 6 * head_metric]
        # Clamp the box to non-negative image coordinates.
        if starting_point[0] < 0:
            starting_point[0] = 0
        if starting_point[1] < 0:
            starting_point[1] = 0
        if ending_point[0] < 0:
            ending_point[0] = 0
        if ending_point[1] < 0:
            ending_point[1] = 0
        # Normalize the keypoints
        for i in range(len(keypoints)):
            keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
            keypoints[i][1] = (keypoints[i][1] - ending_point[1]) / (starting_point[1] - ending_point[1])
        return keypoints

    def normalize_hand(self, keypoints):
        """Normalize hand keypoints into a padded square bounding box.

        Mutates and returns `keypoints`; returns them unchanged for empty or
        degenerate (zero-area) hands.
        """
        x_values = [keypoints[i][0] for i in range(len(keypoints)) if keypoints[i][0] != 0]
        y_values = [keypoints[i][1] for i in range(len(keypoints)) if keypoints[i][1] != 0]
        if not x_values or not y_values:
            return keypoints
        width, height = max(x_values) - min(x_values), max(y_values) - min(y_values)
        # Pad 10% on the larger side, and equalize the smaller side so the
        # resulting box is square.
        if width > height:
            delta_x = 0.1 * width
            delta_y = delta_x + ((width - height) / 2)
        else:
            delta_y = 0.1 * height
            delta_x = delta_y + ((height - width) / 2)
        starting_point = (min(x_values) - delta_x, min(y_values) - delta_y)
        ending_point = (max(x_values) + delta_x, max(y_values) + delta_y)
        if ending_point[0] - starting_point[0] == 0 or ending_point[1] - starting_point[1] == 0:
            return keypoints
        # normalize keypoints
        for i in range(len(keypoints)):
            keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
            keypoints[i][1] = (keypoints[i][1] - starting_point[1]) / (ending_point[1] - starting_point[1])
        return keypoints

    def get_embedding(self, keypoints):
        """Embed one keypoint sequence (list of per-frame arrays).

        NOTE(review): unlike export_embeddings.py this calls the model
        without an attention mask — confirm the model tolerates that.
        """
        # run model on frame
        self.model.eval()
        with torch.no_grad():
            keypoints = torch.from_numpy(np.array([keypoints])).float().to(device)
            new_embeddings = self.model(keypoints).cpu().numpy().tolist()[0]
        return new_embeddings

    def predict(self, embeddings):
        """Classify a precomputed embedding with the configured predictor."""
        return self.predictor.predict(embeddings)

    def make_prediction(self, keypoints):
        """Embed a keypoint sequence and classify it in one step
        (equivalent to get_embedding followed by predict)."""
        # run model on frame
        self.model.eval()
        with torch.no_grad():
            keypoints = torch.from_numpy(np.array([keypoints])).float().to(device)
            new_embeddings = self.model(keypoints).cpu().numpy().tolist()[0]
        return self.predictor.predict(new_embeddings)

    def validation(self):
        """Evaluate on validation_data.npy / validation_labels.npy and print
        the resulting accuracy."""
        # load validation data
        validation_data = np.load('validation_data.npy', allow_pickle=True)
        validation_labels = np.load('validation_labels.npy', allow_pickle=True)
        # run model on validation data
        self.model.eval()
        with torch.no_grad():
            validation_embeddings = self.model(torch.from_numpy(validation_data).float().to(device)).cpu().numpy()
        # predict validation data
        predictions = self.predictor.predict(validation_embeddings)
        # calculate accuracy
        correct = 0
        for i in range(len(predictions)):
            if predictions[i] == validation_labels[i]:
                correct += 1
        accuracy = correct / len(predictions)
        print('Accuracy: ' + str(accuracy))

34
predictions/svm_model.py Normal file
View File

@ -0,0 +1,34 @@
from sklearn import svm
class SVM:
    """SVM classifier over precomputed embedding vectors.

    Mirrors the KNearestNeighbours interface: set_embeddings(df) fits the
    model, predict(x) returns (label_name, score).
    """

    def __init__(self, type="ovo"):
        self.label_name_to_label = None   # lookup DataFrame: label_name <-> label
        self.clf = None                   # fitted sklearn SVC
        self.embeddings_list = None       # parsed embedding vectors
        self.labels = None                # integer labels aligned with the vectors
        self.type = type                  # SVC decision_function_shape ('ovo'/'ovr')

    def set_embeddings(self, embeddings):
        """Parse the embeddings DataFrame and fit the classifier."""
        # 'embeddings2' is a stringified float list; convert it back.
        embeddings["embeddings"] = embeddings["embeddings2"].apply(
            lambda x: [float(i) for i in x[1:-1].split(", ")])
        df = embeddings.drop(columns=['embeddings2'])
        self.embeddings_list = df["embeddings"].tolist()
        self.labels = df["labels"].tolist()
        # Keep a deduplicated label_name <-> label lookup table.  The .copy()
        # avoids pandas chained-assignment warnings when renaming columns on
        # a slice of df.
        self.label_name_to_label = df[["label_name", "labels"]].copy()
        self.label_name_to_label.columns = ["label_name", "label"]
        self.label_name_to_label = self.label_name_to_label.drop_duplicates()
        self.train()

    def train(self):
        """Fit the SVC on the stored embeddings."""
        self.clf = svm.SVC(decision_function_shape=self.type, probability=True)
        self.clf.fit(self.embeddings_list, self.labels)

    def predict(self, key_points_embeddings):
        """Return (label_name, log_probability) for a single embedding row."""
        label = self.clf.predict(key_points_embeddings)
        score = self.clf.predict_log_proba(key_points_embeddings)
        label = label.item()
        # Bug fix: predict_log_proba columns are ordered by clf.classes_, not
        # by raw label value; look up the column position explicitly instead
        # of indexing with the label itself.
        class_index = list(self.clf.classes_).index(label)
        name = self.label_name_to_label.loc[
            self.label_name_to_label["label"] == label]["label_name"].iloc[0]
        return name, score[0][class_index]

File diff suppressed because one or more lines are too long

137
predictions/validation.py Normal file
View File

@ -0,0 +1,137 @@
import json
import os
import time
import cv2
import numpy as np
from matplotlib import pyplot as plt
from predictions.k_nearest import KNearestNeighbours
from predictions.predictor import Predictor
from predictions.svm_model import SVM
buffer_size = 15
def predict_video(predictor, path_video):
    """Extract normalized keypoints from a video, subsampled to ~15 fps.

    Returns a list of per-frame keypoint arrays; frames where MediaPipe
    found no pose/hands are skipped.
    """
    # open mp4 video
    cap = cv2.VideoCapture(path_video)
    buffer = []
    ret, img = cap.read()  # img is (H, W, C), or None at end of stream
    desired_fps = 15
    original_fps = int(cap.get(cv2.CAP_PROP_FPS))
    print("Original FPS: ", original_fps)
    # Calculate the frame skipping rate based on desired frame rate
    # (keep every frame_skip-th frame; never skip below 1).
    frame_skip = max(original_fps // desired_fps, 1)
    print("Frame skip: ", frame_skip)
    frame_number = 0
    while img is not None:
        # Perf fix: only run the (expensive) MediaPipe keypoint extraction on
        # frames that survive the frame-skip test; previously every frame was
        # processed and most results thrown away.
        if frame_number % frame_skip == 0:
            pose = predictor.extract_keypoints(img)
            if pose is not None:
                buffer.append(pose)
        frame_number += 1
        ret, img = cap.read()
    print(len(buffer))
    return buffer
def get_embeddings(predictor, buffer, name):
    """Embed every sliding window of ``buffer_size`` frames from ``buffer``.

    The embeddings are also dumped as JSON to
    ``predictions/test_embeddings/<name>.csv`` for later inspection.
    (An on-disk cache-load path used to exist here but is currently disabled.)

    Returns:
        list: one embedding per window, in temporal order.
    """
    windows = (buffer[end - buffer_size:end]
               for end in range(buffer_size, len(buffer)))
    embeddings = [predictor.get_embedding(window) for window in windows]
    with open("predictions/test_embeddings/" + name + ".csv", 'w') as f:
        json.dump(embeddings, f)
    return embeddings
def compare_embeddings(predictor, embeddings, label_video):
    """Classify each embedding and record whether it matches the video label.

    Returns a list of dicts with keys: "label", "score", "label_video",
    and "correct" (predicted label equals the expected one).
    """
    predictions = [predictor.predict(embedding) for embedding in embeddings]
    return [
        {
            "label": predicted,
            "score": confidence,
            "label_video": label_video,
            "correct": predicted == label_video,
        }
        for predicted, confidence in predictions
    ]
def predict_video_files(predictor, path_video, label_video):
    """Run the full pipeline on one video: sample frames, embed, classify.

    Returns the per-window prediction records from compare_embeddings.
    """
    frames = predict_video(predictor, path_video)
    clip_name = path_video.split("/")[-1].split(".")[0]
    clip_embeddings = get_embeddings(predictor, frames, clip_name)
    return compare_embeddings(predictor, clip_embeddings, label_video)
def get_test_data(data_folder):
    """Collect test-split videos and their labels from ``data_folder``.

    File names are expected to encode metadata as
    "<label>!<split>!....mp4"; only entries whose split part equals
    "test" are returned.

    Returns:
        tuple: (numpy array of file paths, list of label strings).
    """
    all_videos = np.array([data_folder + name
                           for name in os.listdir(data_folder)
                           if name.endswith(".mp4")])
    splits = np.array([path.split("/")[-1].split("!")[1] for path in all_videos])
    test_videos = all_videos[splits == "test"]
    labels = [path.split("/")[-1].split("!")[0] for path in test_videos]
    return test_videos, labels
def test_data(predictor, data_folder):
    """Evaluate ``predictor`` on every test video found in ``data_folder``.

    Returns:
        dict: expected label -> list of run records, each holding the
        per-window predictions, the average time per prediction, and the
        video path.
    """
    results = {}
    for video_path, expected_label in zip(*get_test_data(data_folder)):
        print(video_path, expected_label)
        started = time.time()
        predictions = predict_video_files(predictor, video_path, expected_label)
        elapsed = time.time() - started
        # Divide elapsed time by the number of predictions made so it
        # represents an average execution time per prediction.
        if predictions:
            elapsed /= len(predictions)
        results.setdefault(expected_label, []).append(
            {"predictions": predictions, "elapsed_time": elapsed, "video": video_path})
    print("DONE")
    return results
def plot_general_accuracy(results):
    """Plot accuracy per window index, aggregated over all runs.

    NOTE(review): assumes each ``result`` is a sequence whose first element
    is the per-window prediction list -- confirm against the caller, since
    ``test_data`` returns a dict keyed by label.
    """
    correct_counts = []
    totals = []
    for result in results:
        for index, prediction in enumerate(result[0]):
            if len(correct_counts) <= index:
                correct_counts.append(0)
                totals.append(0)
            correct_counts[index] += 1 if prediction["correct"] else 0
            totals[index] += 1
    # Fix: plot the accuracy ratio, not the raw number of correct
    # predictions -- `totals` was computed but never used before.
    accuracy = [correct / total for correct, total in zip(correct_counts, totals)]
    plt.plot(accuracy)
    plt.show()
if __name__ == "__main__":
type_predictor = "knn"
if type_predictor == "knn":
k = 1
predictor_type = KNearestNeighbours(k)
elif type_predictor == "svm":
predictor_type = SVM()
else:
predictor_type = KNearestNeighbours(1)
# embeddings_path = 'embeddings/basic-signs/embeddings.csv'
embeddings_path = 'embeddings/fingerspelling/embeddings.csv'
predictor = Predictor(embeddings_path, predictor_type)
data_folder = '/home/tibe/Projects/design_project/sign-predictor/data/fingerspelling/data/'
results = test_data(predictor, data_folder)
# write results to a results json file
with open("predictions/test_results/" + type_predictor + ".json", 'w') as f:
json.dump(results, f)
print(results)
# plot_general_accuracy(results)

View File

@ -1,5 +1,5 @@
from argparse import ArgumentParser
from preprocessing.create_wlasl_landmarks_dataset import parse_create_args, create
from preprocessing.create_fingerspelling_dataset import parse_create_args, create
from preprocessing.extract_mediapipe_landmarks import parse_extract_args, extract

View File

@ -0,0 +1,174 @@
import os
import os.path as op
import json
import shutil
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from utils import get_logger
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from normalization.blazepose_mapping import map_blazepose_df
BASE_DATA_FOLDER = 'data/'
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_holistic = mp.solutions.holistic
pose_landmarks = mp_holistic.PoseLandmark
hand_landmarks = mp_holistic.HandLandmark
def get_landmarks_names():
    '''
    Return the comma-separated (x, y) column names for the MediaPipe
    holistic model: pose landmarks first, then left- and right-hand landmarks.
    '''
    def xy_columns(prefix, lmk):
        stem = lmk.name.lower()
        return f'{prefix}{stem}_x,{prefix}{stem}_y'

    sections = [
        ','.join(xy_columns('', lmk) for lmk in pose_landmarks),
        ','.join(xy_columns('left_hand_', lmk) for lmk in hand_landmarks),
        ','.join(xy_columns('right_hand_', lmk) for lmk in hand_landmarks),
    ]
    return ','.join(sections)
def convert_to_str(arr, precision=6):
    """Serialize a numpy array as "[v0,v1,...]" with fixed precision.

    Exact zeros are written as '0' to keep the CSV compact; any non-array
    value is passed through ``str`` unchanged.
    """
    if not isinstance(arr, np.ndarray):
        return str(arr)
    rendered = ('0' if val == 0 else f'{val:.{precision}f}' for val in arr)
    return f"[{','.join(rendered)}]"
def parse_create_args(parser):
    """Register the CLI arguments for the dataset-creation step on ``parser``.

    Requires --landmarks-dataset (output of extract_mediapipe_landmarks.py);
    the remaining options control where the final CSVs go and whether a new
    train/val split is generated.
    """
    parser.add_argument('--landmarks-dataset', '-lmks', required=True,
                        help='Path to folder with landmarks npy files. \
                        You need to run `extract_mediapipe_landmarks.py` script first')
    parser.add_argument('--dataset-folder', '-df', default='data/wlasl',
                        help='Path to folder where original `WLASL_v0.3.json` and `id_to_label.json` are stored. \
                        Note that final CSV files will be saved in this folder too.')
    parser.add_argument('--videos-folder', '-videos', default=None,
                        help='Path to folder with videos. If None, then no information of videos (fps, length, \
                        width and height) will be stored in final csv file')
    parser.add_argument('--num-classes', '-nc', default=100, type=int, help='Number of classes to use in WLASL dataset')
    parser.add_argument('--create-new-split', action='store_true')
    parser.add_argument('--test-size', '-ts', default=0.25, type=float,
                        help='Test split percentage size. Only required if --create-new-split is set')
    # Example: python3 preprocessing.py --landmarks-dataset=data/landmarks -videos data/wlasl/videos
def create(args):
    """Build train/val CSV datasets from extracted landmark `.npy` files.

    Pipeline: collect per-video metadata (label/split parsed from the file
    name, plus fps/size read from the video when --videos-folder is given),
    assign integer label ids, load the landmark arrays, merge everything
    into one dataframe, run the blazepose column mapping, and write either
    a fresh stratified train/val split or a single train.csv.

    Writes `id_to_label.json` and the CSV file(s) into args.dataset_folder.
    """
    logger = get_logger(__name__)
    landmarks_dataset = args.landmarks_dataset
    videos_folder = args.videos_folder
    dataset_folder = args.dataset_folder
    num_classes = args.num_classes
    test_size = args.test_size
    os.makedirs(dataset_folder, exist_ok=True)
    # shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/id_to_label.json'), dataset_folder)
    # shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/WLASL_v0.3.json'), dataset_folder)
    # get files in landmarks_dataset folder
    landmarks_files = os.listdir(landmarks_dataset)
    video_data = []
    for i, file in enumerate(tqdm(landmarks_files)):
        # File names encode label and split as "<label>!<split>.npy".
        label = file.split('!')[0]
        subset = file.split('!')[1].split('.')[0]
        # remove npy and set mp4
        video_id = file.replace('.npy', "")
        video_dict = {'video_id': video_id,
                      'label_name': label,
                      'split': subset}
        if videos_folder is not None:
            cap = cv2.VideoCapture(op.join(videos_folder, f'{video_id}.mp4'))
            if not cap.isOpened():
                logger.warning(f'Video {video_id}.mp4 not found')
                continue
            width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            fps = cap.get(cv2.CAP_PROP_FPS)
            length = cap.get(cv2.CAP_PROP_FRAME_COUNT) / float(cap.get(cv2.CAP_PROP_FPS))
            video_info = {'video_width': width,
                          'video_height': height,
                          'fps': fps,
                          'length': length}
            video_dict.update(video_info)
        # NOTE(review): placement reconstructed from a related diff hunk --
        # the append happens once per file, after optional video info.
        video_data.append(video_dict)
    df_video = pd.DataFrame(video_data)
    video_ids = df_video['video_id'].unique()
    lmks_data = []
    lmks_names = get_landmarks_names().split(',')
    # get labels from df_video
    labels = df_video['label_name'].unique()
    # map labels to ids
    label_to_id = {label: i for i, label in enumerate(labels)}
    # add label_id column to df_video
    df_video['labels'] = df_video['label_name'].map(label_to_id)
    # export to json file as id to label
    id_to_label = {i: label for label, i in label_to_id.items()}
    with open(op.join(dataset_folder, 'id_to_label.json'), 'w') as f:
        json.dump(id_to_label, f, indent=4)
    for video_id in video_ids:
        lmk_fn = op.join(landmarks_dataset, f'{video_id}.npy')
        if not op.exists(lmk_fn):
            logger.warning(f'{lmk_fn} file not found. Skipping')
            continue
        # Transpose so each row is one landmark coordinate series.
        lmk = np.load(lmk_fn).T
        lmks_dict = {'video_id': video_id}
        for lmk_, name in zip(lmk, lmks_names):
            lmks_dict[name] = lmk_
        lmks_data.append(lmks_dict)
    df_lmks = pd.DataFrame(lmks_data)
    df = pd.merge(df_video, df_lmks)
    # Keep the non-landmark columns aside so the blazepose mapping only sees
    # landmark data, then re-attach them.
    aux_columns = ['split', 'video_id', 'labels', 'label_name']
    if videos_folder is not None:
        aux_columns += ['video_width', 'video_height', 'fps', 'length']
    df_aux = df[aux_columns]
    df = map_blazepose_df(df)
    df = pd.concat([df, df_aux], axis=1)
    if args.create_new_split:
        # Stratified split keeps the per-class ratio in both CSVs.
        df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
        print(f'Num classes: {num_classes}')
        print(df_train['labels'].value_counts())
        print(df_test['labels'].value_counts())
        assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
        )), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
        the --create-new-split flag'
        for split, df_split in zip(['train', 'val'],
                                   [df_train, df_test]):
            fn_out = op.join(dataset_folder, f'{split}.csv')
            (df_split.reset_index(drop=True)
             .applymap(convert_to_str)
             .to_csv(fn_out, index=False))
    else:
        fn_out = op.join(dataset_folder, 'train.csv')
        (df.reset_index(drop=True)
         .applymap(convert_to_str)
         .to_csv(fn_out, index=False))

View File

@ -4,6 +4,8 @@ import pandas as pd
from tqdm.auto import tqdm
import json
from normalization.blazepose_mapping import map_blazepose_df
def create(train_landmark_files, train_csv, dataset_folder, test_size):
os.makedirs(dataset_folder, exist_ok=True)
@ -17,15 +19,15 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
mapping = {
'pose_0': 'nose',
'pose_1': 'leftEye',
'pose_2': 'rightEye',
'pose_3': 'leftEar',
'pose_4': 'rightEar',
'pose_5': 'leftShoulder',
'pose_6': 'rightShoulder',
'pose_7': 'leftElbow',
'pose_8': 'rightElbow',
'pose_9': 'leftWrist',
'pose_10': 'rightWrist',
'pose_4': 'rightEye',
'pose_7': 'leftEar',
'pose_8': 'rightEar',
'pose_11': 'leftShoulder',
'pose_12': 'rightShoulder',
'pose_13': 'leftElbow',
'pose_14': 'rightElbow',
'pose_15': 'leftWrist',
'pose_16': 'rightWrist',
'left_hand_0': 'wrist_left',
'left_hand_1': 'thumbCMC_left',
@ -77,7 +79,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
columns.append(f'{v}_X')
columns.append(f'{v}_Y')
for _, row in tqdm(train_df.head(6000).iterrows(), total=6000):
for _, row in tqdm(train_df.head(10000).iterrows(), total=10000):
path, participant_id, sequence_id, sign = row['path'], row['participant_id'], row['sequence_id'], row['sign']
parquet_file = os.path.join(train_landmark_files, str(participant_id), f"{sequence_id}.parquet")
@ -136,6 +138,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
video_data.append(new_landmark_data)
video_data = pd.concat(video_data, axis=0, ignore_index=True)
video_data = map_blazepose_df(video_data, rename=False)
video_data.to_csv(os.path.join(dataset_folder, 'spoter.csv'), index=False)
train_landmark_files = 'data/train_landmark_files'

View File

@ -110,6 +110,7 @@ def create(args):
'length': length}
video_dict.update(video_info)
video_data.append(video_dict)
df_video = pd.DataFrame(video_data)
video_ids = df_video['video_id'].unique()
lmks_data = []
@ -126,9 +127,7 @@ def create(args):
lmks_data.append(lmks_dict)
df_lmks = pd.DataFrame(lmks_data)
print(df_lmks)
df = pd.merge(df_video, df_lmks)
print(df)
aux_columns = ['split', 'video_id', 'labels', 'label_name']
if videos_folder is not None:
aux_columns += ['video_width', 'video_height', 'fps', 'length']

View File

@ -132,6 +132,12 @@ def extract(args):
ret, image_orig = cap.read()
height, width = image_orig.shape[:2]
landmarks_video = []
# make sure fps is 20 by determining the number of frames to be skipped
frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
frame_skip = (frame_rate // 20) - 1
with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:
with mp_holistic.Holistic(
static_image_mode=False,
@ -145,6 +151,9 @@ def extract(args):
print(e)
landmarks = get_landmarks(image_orig, holistic, debug=True)
ret, image_orig = cap.read()
for _ in range(frame_skip):
ret, image_orig = cap.read()
pbar.update(1)
landmarks_video.append(landmarks)
pbar.update(1)
landmarks_video = np.vstack(landmarks_video)

View File

@ -8,7 +8,6 @@ dataset = "data/processed/spoter.csv"
# read the dataset
df = pd.read_csv(dataset)
df = map_blazepose_df(df)
with open("data/sign_to_prediction_index_map.json", "r") as f:
sign_to_prediction_index_max = json.load(f)

View File

@ -1,7 +1,6 @@
pandas
bokeh==2.4.3
boto3>=1.9
clearml==1.6.4
ipywidgets==8.0.4
matplotlib==3.5.3
mediapipe==0.8.11
@ -9,6 +8,8 @@ notebook==6.5.2
opencv-python==4.6.0.66
plotly==5.11.0
scikit-learn==1.0.2
torch
torchvision
clearml==1.10.3
torch==2.0.0
torchvision==0.15.1
tqdm==4.54.1
optuna==3.1.1

View File

@ -15,7 +15,7 @@ from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
import copy
import numpy as np
from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
train_epoch_embedding_online, evaluate_embedding
@ -32,7 +32,7 @@ except ImportError:
pass
PROJECT_NAME = "spoter"
PROJECT_NAME = "SpoterEmbedding"
CLEARML = "clearml"
@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
# Construct the model
if not args.classification_model:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout
)
# if finetune, load the weights from the classification model
if args.finetune:
checkpoint = torch.load(args.checkpoint_path, map_location=device)
slrt_model = SPOTER_EMBEDDINGS(
features=checkpoint["config_args"].vector_length,
hidden_dim=checkpoint["config_args"].hidden_dim,
norm_emb=checkpoint["config_args"].normalize_embeddings,
dropout=checkpoint["config_args"].dropout,
)
else:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout,
)
model_type = 'embed'
if args.hard_triplet_mining == "None":
cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)

View File

@ -1,22 +1,23 @@
#!/bin/sh
python -m train \
--save_checkpoints_every 10 \
--experiment_name "augment_rotate_75_x8" \
--epochs 300 \
python3 -m train \
--save_checkpoints_every 1 \
--experiment_name "Finetune Basic Signs" \
--epochs 100 \
--optimizer "ADAM" \
--lr 0.001 \
--lr 0.00001 \
--batch_size 16 \
--dataset_name "processed" \
--training_set_path "spoter_train.csv" \
--validation_set_path "spoter_test.csv" \
--dataset_name "BasicSigns" \
--training_set_path "train.csv" \
--validation_set_path "val.csv" \
--vector_length 32 \
--epoch_iters -1 \
--scheduler_factor 0 \
--hard_triplet_mining "in_batch" \
--scheduler_factor 0.05 \
--hard_triplet_mining "None" \
--filter_easy_triplets \
--triplet_loss_margin 1 \
--triplet_loss_margin 2 \
--dropout 0.2 \
--augmentations_prob=0.75 \
--hard_mining_scheduler_triplets_threshold=0 \
--normalize_embeddings \
--num_classes 100 \
--tracker=clearml \
--dataset_loader=clearml \
--dataset_project="SpoterEmbedding" \
--finetune \
--checkpoint_path "checkpoints/checkpoint_embed_3006.pth"

View File

@ -81,4 +81,7 @@ def get_default_args():
help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
distance threshold of the batch sorter")
parser.add_argument("--finetune", action='store_true', default=False, help="Fintune the model")
parser.add_argument("--checkpoint_path", type=str, default="")
return parser

1636
visualize_data.ipynb Normal file

File diff suppressed because one or more lines are too long

356
webcam.py
View File

@ -1,338 +1,54 @@
from collections import Counter
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
import torch
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
from models import SPOTER_EMBEDDINGS
from predictions.k_nearest import KNearestNeighbours
from predictions.predictor import Predictor
from predictions.svm_model import SVM
# Initialize MediaPipe Hands model
holistic = mp.solutions.holistic.Holistic(
min_detection_confidence=0.5,
min_tracking_confidence=0.5,
model_complexity=2
)
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils
if __name__ == '__main__':
buffer = []
# open webcam stream
cap = cv2.VideoCapture(0)
BODY_IDENTIFIERS = [
0,
33,
5,
2,
8,
7,
12,
11,
14,
13,
16,
15,
]
HAND_IDENTIFIERS = [
0,
8,
7,
6,
5,
12,
11,
10,
9,
16,
15,
14,
13,
20,
19,
18,
17,
4,
3,
2,
1,
]
def extract_keypoints(image_orig):
image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
results = holistic.process(image)
def extract_keypoints(lmks):
if lmks:
a = np.array([[float(lmk.x), float(lmk.y)] for lmk in lmks.landmark])
return a
return None
def calculate_neck(keypoints):
left_shoulder = keypoints[11]
right_shoulder = keypoints[12]
neck = [(float(left_shoulder[0]) + float(right_shoulder[0])) / 2, (float(left_shoulder[1]) + float(right_shoulder[1])) / 2]
# add neck to keypoints
keypoints = np.append(keypoints, [neck], axis=0)
return keypoints
pose = extract_keypoints(results.pose_landmarks)
pose = calculate_neck(pose)
pose_norm = normalize_pose(pose)
# filter out keypoints that are not in BODY_IDENTIFIERS and make sure they are in the correct order
pose_norm = pose_norm[BODY_IDENTIFIERS]
left_hand = extract_keypoints(results.left_hand_landmarks)
right_hand = extract_keypoints(results.right_hand_landmarks)
if left_hand is None and right_hand is None:
return None
# normalize hands
if left_hand is not None:
left_hand = normalize_hand(left_hand)
type_predictor = "svm"
if type_predictor == "knn":
k = 10
predictor_type = KNearestNeighbours(k)
elif type_predictor == "svm":
predictor_type = SVM()
else:
left_hand = np.zeros((21, 2))
if right_hand is not None:
right_hand = normalize_hand(right_hand)
else:
right_hand = np.zeros((21, 2))
left_hand = left_hand[HAND_IDENTIFIERS]
right_hand = right_hand[HAND_IDENTIFIERS]
# combine pose and hands
pose_norm = np.append(pose_norm, left_hand, axis=0)
pose_norm = np.append(pose_norm, right_hand, axis=0)
# move interval
pose_norm -= 0.5
return pose_norm
predictor_type = KNearestNeighbours(1)
buffer = []
left_shoulder_index = 11
right_shoulder_index = 12
neck_index = 33
nose_index = 0
left_eye_index = 2
# embeddings_path = 'embeddings/basic-signs/embeddings.csv'
embeddings_path = 'embeddings/fingerspelling/embeddings.csv'
# if we have the keypoints, normalize single body, keypoints is numpy array of (identifiers, 2)
def normalize_pose(keypoints):
left_shoulder = keypoints[left_shoulder_index]
right_shoulder = keypoints[right_shoulder_index]
predictor = Predictor(embeddings_path, predictor_type)
neck = keypoints[neck_index]
nose = keypoints[nose_index]
index = 0
# Prevent from even starting the analysis if some necessary elements are not present
if (left_shoulder[0] == 0 or right_shoulder[0] == 0
or (left_shoulder[0] == right_shoulder[0] and left_shoulder[1] == right_shoulder[1])) and (
neck[0] == 0 or nose[0] == 0 or (neck[0] == nose[0] and neck[1] == nose[1])):
return keypoints
while cap.isOpened():
# Wait for key press to exit
if cv2.waitKey(5) & 0xFF == 27:
break
if left_shoulder[0] != 0 and right_shoulder[0] != 0 and (left_shoulder[0] != right_shoulder[0] or left_shoulder[1] != right_shoulder[1]):
shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + ((left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5)
head_metric = shoulder_distance
else:
neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5)
head_metric = neck_nose_distance
ret, frame = cap.read()
pose = predictor.extract_keypoints(frame)
# Set the starting and ending point of the normalization bounding box
starting_point = [keypoints[neck_index][0] - 3 * head_metric, keypoints[left_eye_index][1] + head_metric]
ending_point = [keypoints[neck_index][0] + 3 * head_metric, starting_point[1] - 6 * head_metric]
if pose is None:
cv2.imshow('MediaPipe Hands', frame)
continue
if starting_point[0] < 0:
starting_point[0] = 0
if starting_point[1] < 0:
starting_point[1] = 0
if ending_point[0] < 0:
ending_point[0] = 0
if ending_point[1] < 0:
ending_point[1] = 0
buffer.append(pose)
if len(buffer) > 15:
buffer.pop(0)
# Normalize the keypoints
for i in range(len(keypoints)):
keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
keypoints[i][1] = (keypoints[i][1] - ending_point[1]) / (starting_point[1] - ending_point[1])
if len(buffer) == 15:
label, score = predictor.make_prediction(buffer)
return keypoints
# draw label
cv2.putText(frame, str(label), (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
def normalize_hand(keypoints):
    """Normalize hand keypoints into a padded square bounding box.

    Non-zero coordinates define the bounding box; 10% padding is added on
    the longer axis and the shorter axis is padded further so the box is
    square. Mutates ``keypoints`` in place and returns it. Returns the
    input unchanged when no usable coordinates exist or the box collapses.
    """
    xs = [point[0] for point in keypoints if point[0] != 0]
    ys = [point[1] for point in keypoints if point[1] != 0]
    if not xs or not ys:
        return keypoints

    width = max(xs) - min(xs)
    height = max(ys) - min(ys)
    if width > height:
        delta_x = 0.1 * width
        delta_y = delta_x + ((width - height) / 2)
    else:
        delta_y = 0.1 * height
        delta_x = delta_y + ((height - width) / 2)

    start = (min(xs) - delta_x, min(ys) - delta_y)
    end = (max(xs) + delta_x, max(ys) + delta_y)
    span_x = end[0] - start[0]
    span_y = end[1] - start[1]
    if span_x == 0 or span_y == 0:
        return keypoints

    # Rescale every point (including zeros) into the box's coordinate frame.
    for point in keypoints:
        point[0] = (point[0] - start[0]) / span_x
        point[1] = (point[1] - start[1]) / span_y
    return keypoints
# Load the reference embeddings produced by export_embeddings.py; each row
# pairs an embedding vector with its sign label (see columns used below).
df = pd.read_csv('embeddings/basic-signs/embeddings.csv')
def minkowski_distance_p(x, y, p=2):
    """Return the p-th power of the Minkowski distance along the last axis.

    For p == inf the Chebyshev distance (max |y - x|) is returned, for
    p == 1 the Manhattan distance; otherwise sum(|y - x| ** p) without
    taking the p-th root (callers apply the root themselves).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Promote both operands to a common dtype that is at least float64,
    # without collapsing complex input to plain float (addresses #10262).
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype),
                                       'float64')
    x = x.astype(common_datatype)
    y = y.astype(common_datatype)

    diff = np.abs(y - x)
    if p == np.inf:
        return np.amax(diff, axis=-1)
    if p == 1:
        return np.sum(diff, axis=-1)
    return np.sum(diff ** p, axis=-1)
def minkowski_distance(x, y, p=2):
    """Return the Minkowski distance along the last axis.

    Delegates to ``minkowski_distance_p`` and applies the p-th root,
    except for p == inf and p == 1 where no root is needed.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    powered = minkowski_distance_p(x, y, p)
    return powered if p == np.inf or p == 1 else powered ** (1. / p)
def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
    """Pairwise Minkowski distance matrix between two sets of vectors.

    keypoints: array-like of shape (m, k); embeddings: array-like of
    shape (n, k). Returns an (m, n) array of distances. ``threshold``
    bounds the m*n*k size above which the computation falls back to a
    row/column loop to limit peak memory. (Adapted from SciPy's
    ``distance_matrix``.)
    """
    x = np.array(keypoints)
    m, k = x.shape
    y = np.asarray(embeddings)
    n, kk = y.shape

    if k != kk:
        raise ValueError(f"x contains {k}-dimensional vectors but y contains "
                         f"{kk}-dimensional vectors")

    if m*n*k <= threshold:
        # Small problem: one fully broadcast computation, O(m*n*k) memory.
        return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)
    else:
        # Large problem: fill the result along the shorter dimension so each
        # intermediate stays small.
        result = np.empty((m,n),dtype=float)  # FIXME: figure out the best dtype
        if m < n:
            for i in range(m):
                result[i,:] = minkowski_distance(x[i],y,p)
        else:
            for j in range(n):
                result[:,j] = minkowski_distance(x,y[j],p)
        return result
# Restore the trained embedding model; the config args stored alongside the
# weights rebuild the exact architecture used at training time.
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])

embeddings = df.drop(columns=['labels', 'label_name', 'embeddings'])
# Convert each embedding from its string form "[v0, v1, ...]" (column
# "embeddings2") back into a list of floats.
embeddings["embeddings"] = embeddings["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
# drop embeddings2
embeddings = embeddings.drop(columns=['embeddings2'])
# Keep only the parsed vectors as a plain list for distance computations.
embeddings = embeddings["embeddings"].tolist()
def make_prediction(keypoints):
    """Embed a window of keypoint frames and classify it by nearest neighbours.

    Uses the module-level ``model`` (SPOTER embeddings net), ``embeddings``
    (reference vectors) and ``df`` (their labels). Takes the 5 closest
    reference embeddings, picks the most common label among them, and scores
    it with the mean distance of that label's matches.

    Returns:
        tuple: (label_name, mean_distance_score).
    """
    # run model on frame
    model.eval()
    with torch.no_grad():
        # Wrap the window in a batch dimension before the forward pass.
        keypoints = torch.from_numpy(np.array([keypoints])).float().to(device)
        # NOTE(review): distance_matrix below expects a 2-D first argument,
        # so the model output indexed by [0] is presumably (1, dim) -- confirm.
        new_embeddings = model(keypoints).cpu().numpy().tolist()[0]
    # calculate distance matrix
    dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
    # get the 5 closest matches and select the class that is most common and use the average distance as the score
    # get the 5 closest matches
    indeces = np.argsort(dist_matrix)[0][:5]
    # get the labels
    labels = df["label_name"].iloc[indeces].tolist()
    c = Counter(labels).most_common()[0][0]
    # filter indeces to only include the most common label
    indeces = [i for i in indeces if df["label_name"].iloc[i] == c]
    # get the average distance
    score = np.mean(dist_matrix[0][indeces])
    return c, score
# open webcam stream
cap = cv2.VideoCapture(0)
# Main loop: keep a sliding window of the 15 most recent frames with a
# detected pose and classify it continuously, drawing the current label
# and score onto the preview window.
while cap.isOpened():
    # read frame
    ret, frame = cap.read()
    pose = extract_keypoints(frame)
    if pose is None:
        # Show the frame unannotated when no pose is detected.
        cv2.imshow('MediaPipe Hands', frame)
        continue
    buffer.append(pose)
    if len(buffer) > 15:
        buffer.pop(0)
    if len(buffer) == 15:
        label, score = make_prediction(buffer)
        # draw label
        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
    # Show the frame
    cv2.imshow('MediaPipe Hands', frame)
    # Wait for key press to exit (ESC = 27).
    if cv2.waitKey(5) & 0xFF == 27:
        break
# open video A.mp4
# cap = cv2.VideoCapture('E.mp4')
# while cap.isOpened():
# # read frame
# ret, frame = cap.read()
# if frame is None:
# break
# pose = extract_keypoints(frame)
# buffer.append(pose)
# label, score = make_prediction(buffer)
# print(label, score)