Draft: Multiple prediction methods #2
.gitignore (vendored) — 1 line changed
@@ -155,3 +155,4 @@ out-img/
converted_models/
*.pth
*.onnx
.devcontainer
README2.md — new file (20 lines)
@@ -0,0 +1,20 @@
# Spoter Embeddings

## Creating dataset

First, put all your videos into a single folder. The keypoints can then be extracted from the videos with the following command, which stores them in \<path-to-landmarks-folder\>:

```
python3 preprocessing.py extract --videos-folder <path-to-videos-folder> --output-landmark <path-to-landmarks-folder>
```

When this is done, the dataset can be created using the following command (the flags in parentheses are optional and create a new train/test split):

```
python3 preprocessing.py create --landmarks-dataset <path-to-landmarks-folder> --videos-folder <path-to-videos-folder> --dataset-folder <dataset-output-folder> (--create-new-split --test-size <test-percentage>)
```

The above command generates a train (and val) csv file containing all the extracted keypoints. These files can then be used to train models or to generate embeddings.
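For reference, the generated csv stores one row per video: metadata columns (e.g. `label_name`, `labels`, `video_id`, `fps`) plus one column per keypoint coordinate, where each cell holds the whole sequence as a stringified list. A minimal sketch of inspecting it (the path and column name follow the preprocessing scripts later in this PR; adjust them to your dataset folder):

```
import pandas as pd

df = pd.read_csv("data/processed/spoter_train.csv", encoding="utf-8")
print(df.columns)

# each keypoint cell is a string like "[0.41,0.42,...]", one value per frame
frames = df.loc[0, "leftShoulder_X"][1:-1].split(",")
print(len(frames))
```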
## Creating Embeddings

The embeddings can be created using the following command:

```
python3 export_embeddings.py --checkpoint <path-to-checkpoints-file> --dataset <path-to-dataset-file> --output <embeddings-output-file>
```

The command above generates the embeddings for a given dataset and saves them as a csv file.
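As a quick sanity check, the exported csv can be loaded back like this (a minimal sketch; the `embeddings2` column name comes from `export_embeddings.py`, and the path is one of the files added in this PR):

```
import json

import pandas as pd

df = pd.read_csv("embeddings/basic-signs/embeddings-online-dict.csv")

# `embeddings2` holds each vector as a JSON-style list string; parse it back
vectors = df["embeddings2"].apply(json.loads)
print(df["label_name"].iloc[0], len(vectors.iloc[0]))  # e.g. BEDANKEN 32
```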
checkpoints/checkpoint_embed_3006.pth — new binary file (not shown)
checkpoints/checkpoint_embed_3835.pth — new binary file (not shown)
embeddings/basic-signs/embeddings-online-dict.csv — new file (91 lines)
@@ -0,0 +1,91 @@
embeddings,label_name,labels,embeddings2
"[[ 0.01295076 -0.1074132 0.5111911 0.08452003 -0.7638924 0.6458221
-1.3892679 -1.1791427 -0.30736607 -0.41543546 -0.6358013 0.31411174
0.878936 1.7265923 -0.09298562 0.12005516 0.26995468 -1.5934305
-0.12619524 0.9111336 0.91827893 0.5948979 0.90334046 -0.84089845
-0.29575598 -0.3024254 -0.03074584 1.4402957 -0.22309914 0.54750854
-0.68767273 -0.0665718 ]]",BEDANKEN,0,"[0.012950755655765533, -0.10741320252418518, 0.5111911296844482, 0.08452003449201584, -0.763892412185669, 0.6458221077919006, -1.389267921447754, -1.179142713546753, -0.3073660731315613, -0.41543546319007874, -0.6358013153076172, 0.31411173939704895, 0.8789359927177429, 1.7265923023223877, -0.0929856151342392, 0.12005516141653061, 0.2699546813964844, -1.593430519104004, -0.12619523704051971, 0.9111335873603821, 0.9182789325714111, 0.5948979258537292, 0.9033404588699341, -0.8408984541893005, -0.2957559823989868, -0.30242541432380676, -0.030745841562747955, 1.440295696258545, -0.22309914231300354, 0.5475085377693176, -0.6876727342605591, -0.06657180190086365]"
"[[ 1.0287632 -0.94657993 0.7852507 -0.37169546 0.3614158 0.5360078
-0.9989285 -0.24402924 -0.47437504 -0.17307773 0.9196662 0.5362411
-0.50366765 0.7600024 0.6701926 -0.5288625 -0.38939306 -1.0722619
-1.33836 1.331829 -0.6958481 0.57767504 0.87002045 -0.59203666
-0.57576096 0.70008135 0.8224436 -0.32370126 0.5276206 -0.62150276
-0.9639951 -0.40879178]]",GOED,1,"[1.0287631750106812, -0.9465799331665039, 0.785250723361969, -0.37169545888900757, 0.3614158034324646, 0.536007821559906, -0.9989284873008728, -0.244029238820076, -0.47437503933906555, -0.17307773232460022, 0.9196661710739136, 0.5362411141395569, -0.5036676526069641, 0.7600023746490479, 0.6701925992965698, -0.528862476348877, -0.38939306139945984, -1.072261929512024, -1.3383599519729614, 1.3318289518356323, -0.6958481073379517, 0.5776750445365906, 0.8700204491615295, -0.5920366644859314, -0.5757609605789185, 0.7000813484191895, 0.8224436044692993, -0.32370126247406006, 0.5276206135749817, -0.6215027570724487, -0.963995099067688, -0.40879178047180176]"
"[[-0.5988377 -1.6454532 0.5696761 0.27801913 -0.39687645 1.5192578
-0.57908726 -1.9851501 1.1918951 -1.3280408 0.27390718 0.71642965
-1.0471189 0.9825219 -0.6625729 -0.28641602 2.0109527 -0.97667754
-2.3978348 2.666861 0.15066166 -1.6922158 0.9681514 -1.1413386
0.08027886 -0.6215762 -0.09759905 0.6157277 -0.01341684 0.9806912
-0.09793527 -1.9003383 ]]",GOEDEMIDDAG,2,"[-0.598837673664093, -1.6454532146453857, 0.5696761012077332, 0.27801913022994995, -0.3968764543533325, 1.5192577838897705, -0.5790872573852539, -1.9851500988006592, 1.1918951272964478, -1.3280408382415771, 0.2739071846008301, 0.7164296507835388, -1.047118902206421, 0.9825218915939331, -0.6625729203224182, -0.28641602396965027, 2.0109527111053467, -0.9766775369644165, -2.3978347778320312, 2.666861057281494, 0.1506616622209549, -1.6922158002853394, 0.9681513905525208, -1.141338586807251, 0.08027885854244232, -0.621576189994812, -0.09759905189275742, 0.6157277226448059, -0.013416841626167297, 0.9806911945343018, -0.0979352742433548, -1.9003382921218872]"
"[[-0.11578767 0.28371835 0.8604608 -0.31061298 -0.78153455 -0.24914108
-2.443886 -0.3267159 0.29237527 -0.7879326 -0.12082034 0.32809633
-1.0938975 0.7125962 0.49117383 -0.3119098 0.15369919 0.39569452
-1.18983 1.2498406 -0.38752007 -0.54896545 -0.620718 0.02654967
-0.10994279 -0.39713857 -0.18035033 -1.554728 0.12178666 -0.5698373
1.1097438 -1.2350351 ]]",GOEDEMORGEN,3,"[-0.11578767001628876, 0.2837183475494385, 0.8604608178138733, -0.3106129765510559, -0.7815345525741577, -0.24914108216762543, -2.4438860416412354, -0.326715886592865, 0.29237526655197144, -0.7879325747489929, -0.12082034349441528, 0.32809633016586304, -1.0938974618911743, 0.7125961780548096, 0.4911738336086273, -0.3119097948074341, 0.15369918942451477, 0.3956945240497589, -1.18982994556427, 1.2498406171798706, -0.38752007484436035, -0.5489654541015625, -0.6207180023193359, 0.02654966711997986, -0.10994279384613037, -0.3971385657787323, -0.18035033345222473, -1.5547280311584473, 0.12178666144609451, -0.5698372721672058, 1.1097438335418701, -1.2350350618362427]"
"[[-0.9557147 0.03368077 1.0923759 -0.41077584 -1.6992741 -0.59566665
-1.1562454 0.8824008 -0.25122 0.625515 1.6720039 0.5639413
-0.35894975 -0.05896414 0.84854823 0.4072139 -1.1763096 -1.5435666
-1.5066593 0.66299886 -1.3121625 1.4410669 -0.6341419 -0.4873513
0.8113033 -0.01502099 0.32262453 2.5855515 -0.7294207 0.7582124
-0.1386989 -0.17011487]]",GOEDENACHT,4,"[-0.9557147026062012, 0.03368076682090759, 1.0923758745193481, -0.4107758402824402, -1.6992740631103516, -0.5956666469573975, -1.1562453508377075, 0.8824008107185364, -0.2512199878692627, 0.6255149841308594, 1.6720038652420044, 0.5639412999153137, -0.35894975066185, -0.058964136987924576, 0.8485482335090637, 0.40721389651298523, -1.176309585571289, -1.5435665845870972, -1.5066592693328857, 0.6629988551139832, -1.3121625185012817, 1.441066861152649, -0.6341419219970703, -0.48735129833221436, 0.8113033175468445, -0.015020990744233131, 0.322624534368515, 2.5855515003204346, -0.7294207215309143, 0.7582123875617981, -0.13869890570640564, -0.17011487483978271]"
"[[ 0.13687058 -1.099074 1.1867181 -1.2216619 -0.965039 -0.6957289
-1.6400954 0.7317837 0.5887569 0.18756844 1.4867041 0.63357025
-1.5853169 0.4157976 0.76919526 0.08512082 -1.0937556 -0.07339763
-2.8580015 1.6835626 -1.4538742 1.2302468 -0.49323368 -0.81243044
-0.11378415 -0.09592562 0.31165385 -0.32617894 0.02981896 -0.01485211
1.1880591 -0.40726334]]",GOEDENAVOND,5,"[0.1368705779314041, -1.0990740060806274, 1.1867181062698364, -1.221661925315857, -0.9650390148162842, -0.6957288980484009, -1.6400953531265259, 0.7317836880683899, 0.5887569189071655, 0.18756844103336334, 1.4867041110992432, 0.6335702538490295, -1.5853168964385986, 0.4157975912094116, 0.7691952586174011, 0.08512081950902939, -1.0937556028366089, -0.07339762896299362, -2.858001470565796, 1.6835626363754272, -1.4538742303848267, 1.2302467823028564, -0.49323368072509766, -0.8124304413795471, -0.11378414928913116, -0.09592562168836594, 0.31165385246276855, -0.3261789381504059, 0.029818959534168243, -0.014852114021778107, 1.1880590915679932, -0.4072633385658264]"
"[[ 1.3673804 1.7083405 0.58775777 0.26056585 -0.29101595 -0.69954723
0.45681733 1.6908163 0.02684825 0.7537957 2.2054908 0.28831166
-1.5712252 -1.6869793 0.8009781 -0.51264 -1.3544708 -0.72502786
-0.31009498 -0.23777717 -1.7524906 1.6333568 1.5263942 0.22317217
-0.40576094 2.2136266 1.4030526 -1.0599502 0.8686069 -1.1141618
-0.01899157 -1.2256656 ]]",JA,6,"[1.3673803806304932, 1.7083405256271362, 0.5877577662467957, 0.260565847158432, -0.29101595282554626, -0.6995472311973572, 0.4568173289299011, 1.6908162832260132, 0.02684824913740158, 0.7537956833839417, 2.205490827560425, 0.2883116602897644, -1.5712251663208008, -1.6869792938232422, 0.8009781241416931, -0.5126399993896484, -1.3544708490371704, -0.725027859210968, -0.3100949823856354, -0.23777717351913452, -1.7524906396865845, 1.6333568096160889, 1.526394248008728, 0.2231721729040146, -0.4057609438896179, 2.2136266231536865, 1.403052568435669, -1.0599502325057983, 0.8686069250106812, -1.1141618490219116, -0.01899157091975212, -1.22566556930542]"
"[[ 1.8208723 -0.59625745 0.15060106 -0.23053254 1.0344827 1.7016335
0.1968202 -0.48188782 0.3004466 -0.53336793 0.89704275 0.26869062
-1.0112952 0.0343281 0.32149622 -0.8189871 0.33571526 -0.57059044
-0.7577773 1.4743904 -0.01735806 -0.4116615 2.3637483 -0.6602343
-0.6101709 1.8139238 1.041902 -0.9006022 0.6082014 -0.23616463
-0.62204087 -0.524899 ]]",LINKS,7,"[1.8208723068237305, -0.5962574481964111, 0.15060105919837952, -0.23053254187107086, 1.034482717514038, 1.7016334533691406, 0.1968201994895935, -0.4818878173828125, 0.30044659972190857, -0.533367931842804, 0.8970427513122559, 0.2686906158924103, -1.011295199394226, 0.034328099340200424, 0.32149621844291687, -0.8189870715141296, 0.33571526408195496, -0.5705904364585876, -0.7577772736549377, 1.4743903875350952, -0.01735806092619896, -0.4116615056991577, 2.36374831199646, -0.660234272480011, -0.6101709008216858, 1.8139238357543945, 1.04190194606781, -0.9006022214889526, 0.6082013845443726, -0.2361646294593811, -0.622040867805481, -0.5248990058898926]"
"[[ 1.3968729 2.340378 0.567814 0.5684975 -0.8795973 -1.0090083
-0.01301649 1.9747832 -0.3257468 0.76960254 1.402801 0.39424965
0.06948688 -0.51904166 1.0961802 -0.34335235 -1.8730735 -1.0801809
0.2751789 -0.96998405 -1.9822828 2.1046805 1.3094746 0.5770473
-0.19693638 1.8973799 0.99536693 0.2668684 0.28499565 -0.9838486
-0.19578832 -0.46982569]]",NEE,8,"[1.396872878074646, 2.3403780460357666, 0.5678139925003052, 0.5684974789619446, -0.8795973062515259, -1.0090082883834839, -0.013016488403081894, 1.974783182144165, -0.3257468044757843, 0.7696025371551514, 1.4028010368347168, 0.39424964785575867, 0.06948687881231308, -0.5190416574478149, 1.0961802005767822, -0.343352347612381, -1.8730734586715698, -1.0801808834075928, 0.2751789093017578, -0.9699840545654297, -1.9822827577590942, 2.1046805381774902, 1.3094745874404907, 0.5770472884178162, -0.19693638384342194, 1.8973798751831055, 0.9953669309616089, 0.2668684124946594, 0.2849956452846527, -0.9838485717773438, -0.19578832387924194, -0.4698256850242615]"
"[[ 1.8143647 -0.9315294 0.28706568 -1.0938002 -0.2451038 0.42331484
-0.31653756 0.00268575 0.89822304 1.0639151 1.7406261 1.1707302
-1.6639041 0.30558044 0.22984704 -1.0260515 -0.08818819 -0.14696571
-1.5269523 0.55015475 -0.12936743 0.93951946 1.3917109 -0.62517196
-1.3869016 1.4833041 1.4647726 -1.0285482 -0.3704591 1.4672718
0.40315557 0.04754414]]",RECHTS,9,"[1.8143646717071533, -0.9315294027328491, 0.28706568479537964, -1.0938001871109009, -0.24510380625724792, 0.4233148396015167, -0.3165375590324402, 0.002685748040676117, 0.8982230424880981, 1.0639151334762573, 1.7406260967254639, 1.1707302331924438, -1.663904070854187, 0.3055804371833801, 0.2298470437526703, -1.0260515213012695, -0.08818819373846054, -0.14696571230888367, -1.5269522666931152, 0.5501547455787659, -0.1293674260377884, 0.939519464969635, 1.391710877418518, -0.625171959400177, -1.386901617050171, 1.4833041429519653, 1.4647725820541382, -1.028548240661621, -0.37045910954475403, 1.4672718048095703, 0.4031555652618408, 0.04754413664340973]"
"[[ 1.681777 -0.79441184 0.07110912 -0.01304399 1.110081 2.1115103
-0.27283993 -1.5035654 0.55927217 -0.9903696 0.05718866 0.19891238
-0.7289791 0.71738267 -0.41428873 -0.58389515 1.1087953 -0.6616588
-0.6330258 1.5830203 0.60013187 -1.0640136 2.3917787 -0.70012283
-0.7359694 0.9847692 0.94458187 -0.5915589 0.3662199 0.39359796
-0.6499281 -0.30565363]]",SALUUT,10,"[1.681777000427246, -0.794411838054657, 0.07110912352800369, -0.013043994084000587, 1.1100809574127197, 2.1115102767944336, -0.2728399336338043, -1.5035654306411743, 0.5592721700668335, -0.9903696179389954, 0.05718865618109703, 0.1989123821258545, -0.7289791107177734, 0.7173826694488525, -0.414288729429245, -0.5838951468467712, 1.1087952852249146, -0.6616588234901428, -0.6330258250236511, 1.5830203294754028, 0.6001318693161011, -1.0640136003494263, 2.3917787075042725, -0.7001228332519531, -0.7359694242477417, 0.9847692251205444, 0.9445818662643433, -0.5915588736534119, 0.3662199079990387, 0.39359796047210693, -0.649928092956543, -0.3056536316871643]"
"[[ 0.56773967 -0.6260798 0.23771092 -0.10016712 0.86817515 0.92371875
-0.16313902 -0.3785063 -0.41955388 -0.586288 -0.33712262 -0.07999519
-0.98128295 0.18787864 -0.36432508 -0.06605241 0.00710656 -0.36359936
-0.00642961 1.257384 0.64647824 -1.1618687 1.3707273 -0.46546304
-0.14522405 0.29306158 0.41898865 -0.12311607 0.24604419 -0.48434448
-0.79076517 0.753163 ]]",SLECHT,11,"[0.5677396655082703, -0.626079797744751, 0.23771092295646667, -0.10016711801290512, 0.8681751489639282, 0.9237187504768372, -0.16313901543617249, -0.37850630283355713, -0.41955387592315674, -0.5862879753112793, -0.3371226191520691, -0.07999519258737564, -0.9812829494476318, 0.18787863850593567, -0.36432507634162903, -0.06605241447687149, 0.007106557488441467, -0.36359935998916626, -0.006429608445614576, 1.257383942604065, 0.6464782357215881, -1.161868691444397, 1.370727300643921, -0.4654630422592163, -0.14522404968738556, 0.29306158423423767, 0.4189886450767517, -0.12311606854200363, 0.24604418873786926, -0.484344482421875, -0.7907651662826538, 0.7531629800796509]"
"[[ 0.9251696 -2.8536968 0.6521715 -1.8256427 0.44136596 0.33299872
-1.4674346 -0.7897836 0.47153538 0.1437631 0.7396151 1.0851275
-1.2100328 1.626544 0.69652647 0.0048107 0.13927785 -0.63199776
-3.0121155 2.1738136 -0.7935879 0.2744053 0.281097 -1.0899199
-0.7062637 -0.04824958 0.3436054 -0.69366074 0.22799876 0.8932708
0.51823634 -0.08065546]]",SMAKELIJK,12,"[0.9251695871353149, -2.853696823120117, 0.6521714925765991, -1.825642704963684, 0.44136595726013184, 0.33299872279167175, -1.4674346446990967, -0.7897835969924927, 0.4715353846549988, 0.14376309514045715, 0.7396150827407837, 1.0851274728775024, -1.2100328207015991, 1.6265439987182617, 0.6965264678001404, 0.004810698330402374, 0.139277845621109, -0.6319977641105652, -3.012115478515625, 2.173813581466675, -0.7935879230499268, 0.274405300617218, 0.2810969948768616, -1.089919924736023, -0.7062637209892273, -0.04824957996606827, 0.3436053991317749, -0.6936607360839844, 0.2279987633228302, 0.8932707905769348, 0.5182363390922546, -0.08065545558929443]"
"[[-1.460053 2.9393604 0.5533726 1.3786266 -1.5367306 -1.2867874
-1.2442408 0.62938243 -0.7092637 -1.0278705 -1.3296479 -0.8105969
0.69997054 0.16618343 0.5113248 0.15563956 -0.4815752 -0.10130378
1.4232539 -1.2767068 -0.3229611 0.45140803 -1.3901687 1.0403453
1.454544 -1.308266 -1.145199 0.661388 -0.20519465 -1.0480955
-0.04290716 -0.36275834]]",SORRY,13,"[-1.4600529670715332, 2.9393603801727295, 0.5533726215362549, 1.3786265850067139, -1.5367306470870972, -1.2867873907089233, -1.2442407608032227, 0.6293824315071106, -0.7092636823654175, -1.027870535850525, -1.3296478986740112, -0.8105968832969666, 0.699970543384552, 0.16618342697620392, 0.5113248229026794, 0.15563955903053284, -0.4815751910209656, -0.10130377858877182, 1.4232538938522339, -1.2767068147659302, -0.32296109199523926, 0.4514080286026001, -1.3901686668395996, 1.040345311164856, 1.454543948173523, -1.308266043663025, -1.145198941230774, 0.6613879799842834, -0.2051946520805359, -1.048095464706421, -0.04290715977549553, -0.3627583384513855]"
"[[ 1.0733138 -0.66348886 0.36737278 0.01765811 0.5730918 1.7617474
-0.45580968 -1.0973133 0.4730748 -0.4665998 0.48594972 0.548745
-0.91712314 0.4846773 0.3774526 -0.5996147 0.28750947 -1.0166394
-0.6239797 1.7935088 -0.03666669 -0.51941574 2.045076 -0.9045104
-0.6312211 0.9698636 1.0522215 -0.14772844 0.7406032 0.2901447
-0.48710442 -0.4619298 ]]",TOT-ZIENS,14,"[1.07331383228302, -0.6634888648986816, 0.3673727810382843, 0.017658105120062828, 0.5730918049812317, 1.7617473602294922, -0.45580968260765076, -1.0973132848739624, 0.4730747938156128, -0.46659979224205017, 0.48594972491264343, 0.5487449765205383, -0.9171231389045715, 0.4846773147583008, 0.3774526119232178, -0.599614679813385, 0.2875094711780548, -1.0166393518447876, -0.6239796876907349, 1.793508768081665, -0.036666687577962875, -0.5194157361984253, 2.0450758934020996, -0.9045103788375854, -0.6312211155891418, 0.9698635935783386, 1.0522215366363525, -0.14772844314575195, 0.7406032085418701, 0.2901447117328644, -0.4871044158935547, -0.4619297981262207]"
export_embeddings.py — new file (97 lines)
@@ -0,0 +1,97 @@
import argparse
import multiprocessing
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from datasets import SLREmbeddingDataset, collate_fn_padd
from datasets.dataset_loader import LocalDatasetLoader
from models.spoter_embedding_model import SPOTER_EMBEDDINGS

# Seed everything so that the exported embeddings are reproducible.
seed = 43
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required by deterministic CUDA matmuls
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
generator = torch.Generator()
generator.manual_seed(seed)


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def parse_args():
    parser = argparse.ArgumentParser(description='Export embeddings')
    parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
    parser.add_argument('--output', type=str, default=None, help='Path to output')
    parser.add_argument('--dataset', type=str, default=None, help='Path to data')
    parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
    return parser.parse_args()


args = parse_args()

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# Load the model; its hyperparameters are stored alongside the weights in "config_args".
checkpoint = torch.load(args.checkpoint, map_location=device)

model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)

model.load_state_dict(checkpoint["state_dict"])

dataset_loader = LocalDatasetLoader()
dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)

data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_padd,
    pin_memory=torch.cuda.is_available(),
    # num_workers=0,  # uncomment this line (and comment out the next one) to disable multiprocessing
    num_workers=multiprocessing.cpu_count(),
    worker_init_fn=seed_worker,
    generator=generator,
)

embeddings = []
with torch.no_grad():
    for inputs, labels, masks in data_loader:
        inputs = inputs.to(device)
        masks = masks.to(device)
        outputs = model(inputs, masks)

        for n in range(outputs.shape[0]):
            embeddings.append(outputs[n].cpu().numpy())

df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist()[0])

if args.format == 'json':
    df.to_json(args.output, orient='records')
elif args.format == 'csv':
    df.to_csv(args.output, index=False)
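Note that `export_embeddings.py` can also emit JSON via the `--format` flag; a usage example (the checkpoint and dataset paths are example values from this branch):

```
python3 export_embeddings.py --checkpoint checkpoints/checkpoint_embed_3835.pth --dataset data/processed/spoter_train.csv --output embeddings.json --format json
```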
export_model.py — new file (71 lines)
@@ -0,0 +1,71 @@
import onnx
import torch

from models.spoter_embedding_model import SPOTER_EMBEDDINGS

# set parameters of the model
model_name = 'embedding_model'

# load the PyTorch model from the .pth file
device = torch.device("cpu")
# if torch.cuda.is_available():
#     device = torch.device("cuda")

CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
# set the model to evaluation mode
model.eval()

model_export = "onnx"
if model_export == "coreml":
    # create a dummy input tensor: (batch, frames, keypoints, coordinates)
    dummy_input = torch.randn(1, 10, 54, 2)
    # set the device for the dummy input
    dummy_input = dummy_input.to(device)
    traced_model = torch.jit.trace(model, dummy_input)

    out = traced_model(dummy_input)
    import coremltools as ct

    # Convert to Core ML
    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
    )

    # Save the Core ML model
    coreml_model.save("out-models/" + model_name + ".mlmodel")
else:
    # create a dummy input tensor: (batch, frames, keypoints, coordinates)
    dummy_input = torch.randn(1, 10, 54, 2)
    # set the device for the dummy input
    dummy_input = dummy_input.to(device)

    # export the model to ONNX format (a single export; the draft exported twice to two paths)
    output_file = 'out-models/' + model_name + '.onnx'
    torch.onnx.export(model,                     # model being run
                      dummy_input,               # model input (or a tuple for multiple inputs)
                      output_file,               # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=9,           # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['X'],         # the model's input names
                      output_names=['Y'])        # the model's output names

    # load the exported ONNX model for verification
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)
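Beyond `onnx.checker`, the exported graph can be smoke-tested with onnxruntime (a minimal sketch assuming the `onnxruntime` package is installed; the input/output names match the export call above):

```
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("out-models/embedding_model.onnx")
# same dummy shape as the export: (batch, frames, keypoints, coordinates)
dummy = np.random.randn(1, 10, 54, 2).astype(np.float32)
(embedding,) = session.run(["Y"], {"X": dummy})
print(embedding.shape)
```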
hyperparam_opt.py — new file (73 lines)
@@ -0,0 +1,73 @@
from clearml import Task
from clearml.automation import UniformParameterRange, UniformIntegerParameterRange, DiscreteParameterRange
from clearml.automation import HyperParameterOptimizer
from clearml.automation.optuna import OptimizerOptuna
from optuna.pruners import MedianPruner

task = Task.init(
    project_name='SpoterEmbedding',
    task_name='Automatic Hyper-Parameter Optimization',
    task_type=Task.TaskTypes.optimizer,
    reuse_last_task_id=False
)


optimizer = HyperParameterOptimizer(
    # the task to be optimized; it must already exist in the system so it can be cloned
    base_task_id="4504e0b3ec6745249d3d4c94d3d40652",
    # the hyperparameters to optimize
    hyper_parameters=[
        # epochs
        DiscreteParameterRange('Args/epochs', [200]),
        # learning rate
        UniformParameterRange('Args/lr', 0.000001, 0.01),
        # optimizer
        DiscreteParameterRange('Args/optimizer', ['ADAM', 'SGD']),
        # vector length
        UniformIntegerParameterRange('Args/vector_length', 10, 100),
    ],
    # the objective metric to minimize
    objective_metric_title='train_loss',
    objective_metric_series='loss',
    objective_metric_sign='min',

    # the underlying search strategy
    optimizer_class=OptimizerOptuna,

    # optimization budget
    execution_queue='default',
    optimization_time_limit=360,
    compute_time_limit=480,
    total_max_jobs=20,
    min_iteration_per_job=0,
    max_iteration_per_job=150000,
    pool_period_min=0.1,

    save_top_k_tasks_only=3,
    optuna_pruner=MedianPruner(),
)


def job_complete_callback(
    job_id,                  # type: str
    objective_value,         # type: float
    objective_iteration,     # type: int
    job_parameters,          # type: dict
    top_performance_job_id   # type: str
):
    print('Job completed!', job_id, objective_value, objective_iteration, job_parameters)
    if job_id == top_performance_job_id:
        print('WOOT WOOT we broke the record! Objective reached {}'.format(objective_value))


task.execute_remotely(queue_name='hypertuning', exit_process=True)

optimizer.set_report_period(0.3)

optimizer.start(job_complete_callback=job_complete_callback)

optimizer.wait()

top_exp = optimizer.get_top_experiments(top_k=3)
print([t.id for t in top_exp])

optimizer.stop()
@@ -61,20 +61,25 @@ def map_blazepose_keypoint(column):
    return f"{mapped}_{hand}{suffix}"


def map_blazepose_df(df):
def map_blazepose_df(df, rename=True):
    to_drop = []
    if rename:
        renamings = {}
        for column in df.columns:
            mapped_column = map_blazepose_keypoint(column)
            if mapped_column:
                renamings[column] = mapped_column
            else:
                to_drop.append(column)
        df = df.rename(columns=renamings)

    for index, row in df.iterrows():

        sequence_size = len(row["leftEar_Y"])
        lsx = row["leftShoulder_X"]
        rsx = row["rightShoulder_X"]
        lsy = row["leftShoulder_Y"]
        rsy = row["rightShoulder_Y"]
        # convert all to lists
        lsx = lsx[1:-1].split(",")
        rsx = rsx[1:-1].split(",")
        lsy = lsy[1:-1].split(",")
        rsy = rsy[1:-1].split(",")
        sequence_size = len(lsx)

        neck_x = []
        neck_y = []
        # Treat each element of the sequence (analyzed frame) individually
@@ -84,4 +89,5 @@ def map_blazepose_df(df):
        df.loc[index, "neck_X"] = str(neck_x)
        df.loc[index, "neck_Y"] = str(neck_y)

    df.drop(columns=to_drop, inplace=True)
    return df
@@ -5,30 +5,30 @@ import pandas as pd
from normalization.hand_normalization import normalize_hands_full
from normalization.body_normalization import normalize_body_full

DATASET_PATH = './data/wlasl'
DATASET_PATH = './data/processed'
# Load the dataset
df = pd.read_csv(os.path.join(DATASET_PATH, "WLASL100_train.csv"), encoding="utf-8")
df = pd.read_csv(os.path.join(DATASET_PATH, "spoter_train.csv"), encoding="utf-8")

print(df.head())
print(df.columns)

# Retrieve metadata
video_size_heights = df["video_height"].to_list()
video_size_widths = df["video_width"].to_list()
# video_size_heights = df["video_height"].to_list()
# video_size_widths = df["video_width"].to_list()

# Delete redundant (non-related) properties
del df["video_height"]
del df["video_width"]
# del df["video_height"]
# del df["video_width"]

# Temporarily remove other relevant metadata
labels = df["labels"].to_list()
video_fps = df["fps"].to_list()
signs = df["sign"].to_list()

del df["labels"]
del df["fps"]
del df["split"]
del df["video_id"]
del df["label_name"]
del df["length"]
del df["sign"]
del df["path"]
del df["participant_id"]
del df["sequence_id"]

# Convert the strings into lists

@@ -41,7 +41,7 @@ for column in df.columns:

# Perform the normalizations
df = normalize_hands_full(df)
df, invalid_row_indexes = normalize_body_full(df)
# df, invalid_row_indexes = normalize_body_full(df)

# Clear lists of items from deleted rows
# labels = [t for i, t in enumerate(labels) if i not in invalid_row_indexes]
@@ -49,6 +49,6 @@ df, invalid_row_indexes = normalize_body_full(df)

# Return the metadata back to the dataset
df["labels"] = labels
df["fps"] = video_fps
df["sign"] = signs

df.to_csv(os.path.join(DATASET_PATH, "wlasl_train_norm.csv"), encoding="utf-8", index=False)
df.to_csv(os.path.join(DATASET_PATH, "spoter_train_norm.csv"), encoding="utf-8", index=False)
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "c20f7fd5",
"metadata": {},
"outputs": [],
@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "ada032d0",
"metadata": {},
"outputs": [],
@@ -22,13 +22,12 @@
"import os\n",
"import os.path as op\n",
"import pandas as pd\n",
"import json\n",
"import base64"
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "05682e73",
"metadata": {},
"outputs": [],
@@ -38,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "fede7684",
"metadata": {},
"outputs": [],
@@ -48,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "ce531994",
"metadata": {},
"outputs": [],
@@ -64,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "f4a2d672",
"metadata": {},
"outputs": [],
@@ -87,17 +86,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "1d9db764",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f29f89e3ed0>"
"<torch._C.Generator at 0x7f010919d710>"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -119,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "71224139",
"metadata": {},
"outputs": [],
@@ -155,7 +154,7 @@
"# checkpoint = torch.load(model.get_weights())\n",
"\n",
"## Set your path to checkpoint here\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_6.pth\"\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"\n",
"model = SPOTER_EMBEDDINGS(\n",
@@ -169,27 +168,28 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 24,
"id": "ba6b58f0",
"metadata": {},
"outputs": [],
"source": [
"SL_DATASET = 'wlasl' # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
"\n",
"if SL_DATASET == 'fingerspelling':\n",
" dataset_name = \"fingerspelling\"\n",
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
"elif SL_DATASET == 'wlasl':\n",
" dataset_name = \"wlasl\"\n",
" num_classes = 100\n",
" split_dataset_path = \"WLASL100_train.csv\"\n",
"else:\n",
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
" num_classes = 64\n",
" split_dataset_path = \"LSA64_{}.csv\"\n",
" \n",
" split_dataset_path = \"WLASL100_{}.csv\"\n",
"elif SL_DATASET == 'basic-signs':\n",
" dataset_name = \"basic-signs\"\n",
" split_dataset_path = \"basic-signs_{}.csv\"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 25,
"id": "5643a72c",
"metadata": {},
"outputs": [],
@@ -269,6 +269,8 @@
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
" k += 1\n",
" inputs = inputs.to(device)\n",
" \n",
"\n",
" masks = masks.to(device)\n",
" outputs = model(inputs, masks)\n",
" for n in range(outputs.shape[0]):\n",
@@ -285,7 +287,7 @@
{
"data": {
"text/plain": [
"(810, 810)"
"(164, 164)"
]
},
"execution_count": 19,
@@ -299,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
@@ -311,6 +313,70 @@
" df['embeddings'] = embeddings_split[split]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0b9fb9c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
" ... \n",
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
"Name: embeddings, Length: 164, dtype: object\n",
"0 TOT-ZIENS\n",
"1 GOED\n",
"2 GOEDENACHT\n",
"3 NEE\n",
"4 SLECHT\n",
" ... \n",
"159 SORRY\n",
"160 GOEDEMORGEN\n",
"161 LINKS\n",
"162 TOT-ZIENS\n",
"163 GOED\n",
"Name: label_name, Length: 164, dtype: object\n",
"0 0\n",
"1 1\n",
"2 2\n",
"3 3\n",
"4 4\n",
" ..\n",
"159 7\n",
"160 5\n",
"161 13\n",
"162 0\n",
"163 1\n",
"Name: labels, Length: 164, dtype: int64\n"
]
}
],
"source": [
"print(dfs['train'][\"embeddings\"])\n",
"print(dfs['train'][\"label_name\"])\n",
"print(dfs['train'][\"labels\"])\n",
"\n",
"# only keep these columns\n",
"dfs['train'] = dfs['train'][['embeddings', 'label_name', 'labels']]\n",
"\n",
"# convert embeddings to string\n",
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
"\n",
"# save the dfs['train']\n",
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
]
},
{
"cell_type": "markdown",
"id": "2951638d",
@@ -322,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"id": "7399b8ae",
"metadata": {},
"outputs": [
@@ -331,16 +397,16 @@
"output_type": "stream",
"text": [
"Using centroids only\n",
"Top-1 accuracy: 5.19 %\n",
"Top-5 embeddings class match: 17.65 % (Picks any class in the 5 closest embeddings)\n",
"Top-1 accuracy: 80.00 %\n",
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
"\n",
"################################\n",
"\n",
"Using all embeddings\n",
"Top-1 accuracy: 5.31 %\n",
"5-nn accuracy: 5.56 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 15.43 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 15.56 % (Picks the 5 closest distinct classes)\n",
"Top-1 accuracy: 80.00 %\n",
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
"\n",
"################################\n",
"\n"
@@ -375,13 +441,13 @@
" sorted_labels = labels[argsort]\n",
" if sorted_labels[0] == true_label:\n",
" top1 += 1\n",
" if use_centroids:\n",
" good_samples.append(df_val.loc[i, 'video_id'])\n",
" else:\n",
" good_samples.append((df_val.loc[i, 'video_id'],\n",
" df_train.loc[argsort[0], 'video_id'],\n",
" i,\n",
" argsort[0]))\n",
" # if use_centroids:\n",
" #     good_samples.append(df_val.loc[i, 'video_id'])\n",
" # else:\n",
" #     good_samples.append((df_val.loc[i, 'video_id'],\n",
" #         df_train.loc[argsort[0], 'video_id'],\n",
" #         i,\n",
" #         argsort[0]))\n",
"\n",
"\n",
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
File diff suppressed because one or more lines are too long
predictions/k_nearest.py — new file (93 lines)
@@ -0,0 +1,93 @@
import numpy as np
from collections import Counter

# TODO: scale distances relative to the intra-class distances?
# TODO: find a more efficient way to handle k=1


def minkowski_distance_p(x, y, p=2):
    x = np.asarray(x)
    y = np.asarray(y)

    # Find the smallest common datatype with float64 (return type of this
    # function) - addresses #10262.
    # Don't just cast to float64 for complex input case.
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype),
                                       'float64')

    # Make sure x and y are NumPy arrays of correct datatype.
    x = x.astype(common_datatype)
    y = y.astype(common_datatype)

    if p == np.inf:
        return np.amax(np.abs(y - x), axis=-1)
    elif p == 1:
        return np.sum(np.abs(y - x), axis=-1)
    else:
        return np.sum(np.abs(y - x) ** p, axis=-1)


def minkowski_distance(x, y, p=2):
    x = np.asarray(x)
    y = np.asarray(y)
    if p == np.inf or p == 1:
        return minkowski_distance_p(x, y, p)
    else:
        return minkowski_distance_p(x, y, p) ** (1. / p)


class KNearestNeighbours:
    def __init__(self, k=5):
        self.k = k
        self.embeddings = None
        self.embeddings_list = None

    def set_embeddings(self, embeddings):
        self.embeddings = embeddings
        df = embeddings.drop(columns=['labels', 'label_name', 'embeddings'])
        # convert each embedding from its string form to a list of floats
        df["embeddings"] = df["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
        # drop embeddings2
        df = df.drop(columns=['embeddings2'])
        # to list
        self.embeddings_list = df["embeddings"].tolist()

    def distance_matrix(self, keypoints, p=2, threshold=1000000):
        x = np.array(keypoints)
        m, k = x.shape
        y = np.asarray(self.embeddings_list)
        n, kk = y.shape

        if k != kk:
            raise ValueError(f"x contains {k}-dimensional vectors but y contains "
                             f"{kk}-dimensional vectors")

        if m * n * k <= threshold:
            # small enough to compute in a single broadcast
            return minkowski_distance(x[:, np.newaxis, :], y[np.newaxis, :, :], p)
        else:
            result = np.empty((m, n), dtype=float)  # FIXME: figure out the best dtype
            if m < n:
                for i in range(m):
                    result[i, :] = minkowski_distance(x[i], y, p)
            else:
                for j in range(n):
                    result[:, j] = minkowski_distance(x, y[j], p)
            return result

    def predict(self, key_points_embeddings):
        # calculate the distance matrix between the query and all stored embeddings
        dist_matrix = self.distance_matrix(key_points_embeddings, p=2, threshold=1000000)

        # get the k closest matches
        indices = np.argsort(dist_matrix)[0][:self.k]
        # get their labels and select the class that is most common among them
        labels = self.embeddings["label_name"].iloc[indices].tolist()
        c = Counter(labels).most_common()[0][0]

        # keep only the indices belonging to the most common label
        indices = [i for i in indices if self.embeddings["label_name"].iloc[i] == c]
        # use the average distance to those neighbours as the score
        score = np.mean(dist_matrix[0][indices])
        return c, score
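A minimal usage sketch for the classifier above (the csv is the one added in this PR; the all-zeros query is only a placeholder — in the real pipeline the query vector comes from `Predictor.get_embedding`, and these checkpoints produce 32-dimensional embeddings):

```
import pandas as pd

from predictions.k_nearest import KNearestNeighbours

knn = KNearestNeighbours(k=5)
knn.set_embeddings(pd.read_csv("embeddings/basic-signs/embeddings-online-dict.csv"))

# one query embedding, shaped (1, vector_length), as predict() expects
label, score = knn.predict([[0.0] * 32])
print(label, score)
```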
predictions/plotting.py — new file (86 lines)
@@ -0,0 +1,86 @@
import json

from matplotlib import pyplot as plt


def load_results():
    with open("predictions/test_results/knn.json", 'r') as f:
        results = json.load(f)
    return results


def plot_all():
    results = load_results()
    print(f"average elapsed time to detect a sign: {get_general_elapsed_time(results)}")
    plot_general_accuracy(results)
    for label in results.keys():
        plot_accuracy_per_label(results, label)


def general_accuracy(results):
    label_accuracy = get_label_accuracy(results)
    accuracy = []
    amount = []
    response = []
    for label in label_accuracy.keys():
        for index, value in enumerate(label_accuracy[label]):
            if index >= len(accuracy):
                accuracy.append(0)
                amount.append(0)
            accuracy[index] += label_accuracy[label][index]
            amount[index] += 1
    for a, b in zip(accuracy, amount):
        # stop once fewer than 5 labels still have data at this buffer index
        if b < 5:
            break
        response.append(a / b)
    return response


def plot_general_accuracy(results):
    accuracy = general_accuracy(results)
    plt.plot(accuracy)
    plt.title("General accuracy")
    plt.ylabel('accuracy')
    plt.xlabel('buffer')
    plt.show()


def plot_accuracy_per_label(results, label):
    accuracy = get_label_accuracy(results)
    plt.plot(accuracy[label], label=label)
    plt.title(f"Accuracy per label {label}")
    plt.ylabel('accuracy')
    plt.xlabel('prediction')
    plt.legend()
    plt.show()


def get_label_accuracy(results):
    accuracy = {}
    amount = {}
    response = {}
    for label, predictions in results.items():
        if label not in accuracy:
            accuracy[label] = []
            amount[label] = []
        for prediction in predictions:
            for index, value in enumerate(prediction["predictions"]):
                if index >= len(accuracy[label]):
                    accuracy[label].append(0)
                    amount[label].append(0)
                accuracy[label][index] += 1 if value["correct"] else 0
                amount[label][index] += 1
    for label in accuracy:
        response[label] = []
        for index, value in enumerate(accuracy[label]):
            # require at least 2 samples per index
            if amount[label][index] < 2:
                break
            response[label].append(accuracy[label][index] / amount[label][index])
    return response


def get_general_elapsed_time(results):
    label_time = get_label_elapsed_time(results)
    return sum([label_time[label] for label in results]) / len(results)


def get_label_elapsed_time(results):
    return {label: sum([result["elapsed_time"] for result in results[label]]) / len(results[label]) for label in results}


if __name__ == '__main__':
    plot_all()
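For reference, the field accesses above imply the following shape for `predictions/test_results/knn.json` (a reconstruction — GitHub suppressed the file itself, and the values here are made up):

```
# one entry per test video, grouped by the video's true label
results = {
    "GOED": [
        {
            "elapsed_time": 1.23,  # seconds, written by validation.py
            "predictions": [
                {"label": "GOED", "score": 0.42, "label_video": "GOED", "correct": True},
            ],
        },
    ],
}
```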
predictions/predictor.py — new file (267 lines)
@@ -0,0 +1,267 @@
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
import torch

from models import SPOTER_EMBEDDINGS
from predictions.k_nearest import KNearestNeighbours

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# MediaPipe landmark indices used by the model, in the order it expects them.
BODY_IDENTIFIERS = [0, 33, 5, 2, 8, 7, 12, 11, 14, 13, 16, 15]

HAND_IDENTIFIERS = [0, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13, 20, 19, 18, 17, 4, 3, 2, 1]

CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"


class Predictor:
    def __init__(self, embeddings_path, predictor_type):
        # Initialize the MediaPipe Holistic model
        self.holistic = mp.solutions.holistic.Holistic(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=2
        )

        self.mp_holistic = mp.solutions.holistic
        self.mp_drawing = mp.solutions.drawing_utils
        self.left_shoulder_index = 11
        self.right_shoulder_index = 12
        self.neck_index = 33
        self.nose_index = 0
        self.left_eye_index = 2

        # load the training embeddings csv
        self.embeddings = pd.read_csv(embeddings_path)

        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

        self.model = SPOTER_EMBEDDINGS(
            features=checkpoint["config_args"].vector_length,
            hidden_dim=checkpoint["config_args"].hidden_dim,
            norm_emb=checkpoint["config_args"].normalize_embeddings,
        ).to(device)

        self.model.load_state_dict(checkpoint["state_dict"])

        if predictor_type is None:
            self.predictor = KNearestNeighbours(1)
        else:
            self.predictor = predictor_type
        self.predictor.set_embeddings(self.embeddings)

    def extract_keypoints(self, image_orig):
        image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
        results = self.holistic.process(image)

        def landmarks_to_array(lmks):
            if lmks:
                return np.array([[float(lmk.x), float(lmk.y)] for lmk in lmks.landmark])
            return None

        def calculate_neck(keypoints):
            if keypoints is not None:
                left_shoulder = keypoints[11]
                right_shoulder = keypoints[12]

                # the neck is the midpoint between the two shoulders
                neck = [(float(left_shoulder[0]) + float(right_shoulder[0])) / 2,
                        (float(left_shoulder[1]) + float(right_shoulder[1])) / 2]
                # append the neck to the keypoints
                keypoints = np.append(keypoints, [neck], axis=0)
                return keypoints
            return None

        pose = landmarks_to_array(results.pose_landmarks)
        pose = calculate_neck(pose)
        if pose is None:
            return None
        pose_norm = self.normalize_pose(pose)
        # keep only the keypoints in BODY_IDENTIFIERS, in the expected order
        pose_norm = pose_norm[BODY_IDENTIFIERS]

        left_hand = landmarks_to_array(results.left_hand_landmarks)
        right_hand = landmarks_to_array(results.right_hand_landmarks)

        if left_hand is None and right_hand is None:
            return None

        # normalize hands
        if left_hand is not None:
            left_hand = self.normalize_hand(left_hand)
        else:
            left_hand = np.zeros((21, 2))
        if right_hand is not None:
            right_hand = self.normalize_hand(right_hand)
        else:
            right_hand = np.zeros((21, 2))

        left_hand = left_hand[HAND_IDENTIFIERS]

        right_hand = right_hand[HAND_IDENTIFIERS]

        # combine pose and hands
        pose_norm = np.append(pose_norm, left_hand, axis=0)
        pose_norm = np.append(pose_norm, right_hand, axis=0)

        # move the interval from [0, 1] to [-0.5, 0.5]
        pose_norm -= 0.5

        return pose_norm

    # normalize a single body; keypoints is a numpy array of shape (identifiers, 2)
    def normalize_pose(self, keypoints):
        left_shoulder = keypoints[self.left_shoulder_index]
        right_shoulder = keypoints[self.right_shoulder_index]

        neck = keypoints[self.neck_index]
        nose = keypoints[self.nose_index]

        # Prevent from even starting the analysis if some necessary elements are not present
        if (left_shoulder[0] == 0 or right_shoulder[0] == 0
                or (left_shoulder[0] == right_shoulder[0] and left_shoulder[1] == right_shoulder[1])) and (
                neck[0] == 0 or nose[0] == 0 or (neck[0] == nose[0] and neck[1] == nose[1])):
            return keypoints

        if left_shoulder[0] != 0 and right_shoulder[0] != 0 and (
                left_shoulder[0] != right_shoulder[0] or left_shoulder[1] != right_shoulder[1]):
            shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + (
                (left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5)
            head_metric = shoulder_distance
        else:
            neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5)
            head_metric = neck_nose_distance

        # Set the starting and ending point of the normalization bounding box
        starting_point = [keypoints[self.neck_index][0] - 3 * head_metric,
                          keypoints[self.left_eye_index][1] + head_metric]
        ending_point = [keypoints[self.neck_index][0] + 3 * head_metric, starting_point[1] - 6 * head_metric]

        if starting_point[0] < 0:
            starting_point[0] = 0
        if starting_point[1] < 0:
            starting_point[1] = 0
        if ending_point[0] < 0:
            ending_point[0] = 0
        if ending_point[1] < 0:
            ending_point[1] = 0

        # Normalize the keypoints
        for i in range(len(keypoints)):
            keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
            keypoints[i][1] = (keypoints[i][1] - ending_point[1]) / (starting_point[1] - ending_point[1])

        return keypoints

    def normalize_hand(self, keypoints):
        x_values = [keypoints[i][0] for i in range(len(keypoints)) if keypoints[i][0] != 0]
        y_values = [keypoints[i][1] for i in range(len(keypoints)) if keypoints[i][1] != 0]

        if not x_values or not y_values:
            return keypoints

        # pad the hand's bounding box into a square with a 10 % margin
        width, height = max(x_values) - min(x_values), max(y_values) - min(y_values)
        if width > height:
            delta_x = 0.1 * width
            delta_y = delta_x + ((width - height) / 2)
        else:
            delta_y = 0.1 * height
            delta_x = delta_y + ((height - width) / 2)

        starting_point = (min(x_values) - delta_x, min(y_values) - delta_y)
        ending_point = (max(x_values) + delta_x, max(y_values) + delta_y)

        if ending_point[0] - starting_point[0] == 0 or ending_point[1] - starting_point[1] == 0:
            return keypoints

        # normalize keypoints
        for i in range(len(keypoints)):
            keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
            keypoints[i][1] = (keypoints[i][1] - starting_point[1]) / (ending_point[1] - starting_point[1])

        return keypoints

    def get_embedding(self, keypoints):
        # run the model on a buffer of frames
        self.model.eval()
        with torch.no_grad():
            keypoints = torch.from_numpy(np.array([keypoints])).float().to(device)
            new_embeddings = self.model(keypoints).cpu().numpy().tolist()[0]
        return new_embeddings

    def predict(self, embeddings):
        return self.predictor.predict(embeddings)

    def make_prediction(self, keypoints):
        # embed the buffer of frames and classify the resulting embedding
        return self.predictor.predict(self.get_embedding(keypoints))

    def validation(self):
        # load the validation data
        validation_data = np.load('validation_data.npy', allow_pickle=True)
        validation_labels = np.load('validation_labels.npy', allow_pickle=True)

        # run the model on the validation data
        self.model.eval()
        with torch.no_grad():
            validation_embeddings = self.model(torch.from_numpy(validation_data).float().to(device)).cpu().numpy()

        # predict the validation data
        predictions = self.predictor.predict(validation_embeddings)

        # calculate the accuracy
        correct = 0
        for i in range(len(predictions)):
            if predictions[i] == validation_labels[i]:
                correct += 1
        accuracy = correct / len(predictions)
        print('Accuracy: ' + str(accuracy))
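Putting the pieces together, a sketch of how `Predictor` is driven frame by frame (it mirrors `predictions/validation.py` with its sliding buffer of 15 frames; the webcam index and the embeddings path are example values):

```
import cv2

from predictions.predictor import Predictor

predictor = Predictor("embeddings/basic-signs/embeddings-online-dict.csv", None)  # None -> 1-NN
cap = cv2.VideoCapture(0)
buffer = []
while True:
    ok, frame = cap.read()
    if not ok:
        break
    pose = predictor.extract_keypoints(frame)  # (54, 2) array or None
    if pose is not None:
        buffer.append(pose)
    if len(buffer) >= 15:
        label, score = predictor.make_prediction(buffer[-15:])
        print(label, score)
```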
predictions/svm_model.py — new file (34 lines)
@@ -0,0 +1,34 @@
from sklearn import svm


class SVM:
    def __init__(self, type="ovo"):
        self.label_name_to_label = None
        self.clf = None
        self.embeddings_list = None
        self.labels = None
        self.type = type

    def set_embeddings(self, embeddings):
        # convert each embedding from its string form to a list of floats
        embeddings["embeddings"] = embeddings["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
        # drop embeddings2
        df = embeddings.drop(columns=['embeddings2'])
        # to list
        self.embeddings_list = df["embeddings"].tolist()
        self.labels = df["labels"].tolist()
        self.label_name_to_label = df[["label_name", "labels"]]
        self.label_name_to_label.columns = ["label_name", "label"]
        self.label_name_to_label = self.label_name_to_label.drop_duplicates()

        self.train()

    def train(self):
        self.clf = svm.SVC(decision_function_shape=self.type, probability=True)
        self.clf.fit(self.embeddings_list, self.labels)

    def predict(self, key_points_embeddings):
        label = self.clf.predict(key_points_embeddings)
        score = self.clf.predict_log_proba(key_points_embeddings)
        # TODO: predict_log_proba columns follow self.clf.classes_, so indexing
        # by the raw label assumes the labels are exactly 0..n-1
        label = label.item()
        return self.label_name_to_label.loc[self.label_name_to_label["label"] == label]["label_name"].iloc[0], score[0][label]
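Both predictors expose the same `set_embeddings`/`predict` interface, so the SVM can be swapped in when constructing the `Predictor` (a sketch; the embeddings path is an example):

```
from predictions.predictor import Predictor
from predictions.svm_model import SVM

# KNearestNeighbours(1) is used when predictor_type is None
predictor = Predictor("embeddings/basic-signs/embeddings-online-dict.csv", SVM(type="ovo"))
```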
predictions/test_results/knn.json — new file (1 line)
File diff suppressed because one or more lines are too long
137 predictions/validation.py Normal file
@ -0,0 +1,137 @@
import json
import os
import time

import cv2
import numpy as np
from matplotlib import pyplot as plt

from predictions.k_nearest import KNearestNeighbours
from predictions.predictor import Predictor
from predictions.svm_model import SVM

buffer_size = 15


def predict_video(predictor, path_video):
    # open the mp4 video
    cap = cv2.VideoCapture(path_video)
    buffer = []
    ret, img = cap.read()  # read one frame from the 'capture' object; img is (H, W, C)
    desired_fps = 15
    original_fps = int(cap.get(cv2.CAP_PROP_FPS))
    print("Original FPS: ", original_fps)
    # calculate the frame skipping rate based on the desired frame rate
    frame_skip = original_fps // desired_fps
    if frame_skip == 0:
        frame_skip = 1
    print("Frame skip: ", frame_skip)
    frame_number = 0
    while img is not None:
        pose = predictor.extract_keypoints(img)
        if pose is not None and frame_number % frame_skip == 0:
            buffer.append(pose)
        frame_number += 1
        ret, img = cap.read()
    print(len(buffer))
    return buffer


def get_embeddings(predictor, buffer, name):
    # check if a cached embeddings file already exists for this name
    # if os.path.exists("predictions/test_embeddings/" + name + ".csv"):
    #     print("Loading embeddings from file")
    #     with open("predictions/test_embeddings/" + name + ".csv", 'r') as f:
    #         embeddings = json.load(f)
    # else:
    embeddings = []
    for index in range(buffer_size, len(buffer)):
        embedding = predictor.get_embedding(buffer[index - buffer_size:index])
        embeddings.append(embedding)
    with open("predictions/test_embeddings/" + name + ".csv", 'w') as f:
        json.dump(embeddings, f)
    return embeddings


def compare_embeddings(predictor, embeddings, label_video):
    results = []
    for embedding in embeddings:
        label, score = predictor.predict(embedding)

        results.append({"label": label, "score": score, "label_video": label_video, "correct": label == label_video})
    return results


def predict_video_files(predictor, path_video, label_video):
    buffer = predict_video(predictor, path_video)
    embeddings = get_embeddings(predictor, buffer, path_video.split("/")[-1].split(".")[0])
    return compare_embeddings(predictor, embeddings, label_video)


def get_test_data(data_folder):
    files = np.array([data_folder + f for f in os.listdir(data_folder) if f.endswith(".mp4")])
    train_test = [f.split("/")[-1].split("!")[1] for f in files]
    test_files = files[np.array(train_test) == "test"]
    test_labels = [f.split("/")[-1].split("!")[0] for f in test_files]

    return test_files, test_labels


def test_data(predictor, data_folder):
    results = {}
    for path_video, label_video in zip(*get_test_data(data_folder)):
        print(path_video, label_video)
        start_time = time.time()
        prediction = predict_video_files(predictor, path_video, label_video)
        end_time = time.time()
        elapsed_time = end_time - start_time

        # divide the elapsed time by the number of predictions made so it represents an average execution time
        if len(prediction) > 0:
            elapsed_time /= len(prediction)
        if label_video not in results:
            results[label_video] = []
        results[label_video].append({"predictions": prediction, "elapsed_time": elapsed_time, "video": path_video})

    print("DONE")
    return results


def plot_general_accuracy(results):
    accuracy = []
    amount = []
    for result in results:
        for index, value in enumerate(result[0]):
            if len(accuracy) <= index:
                accuracy.append(0)
                amount.append(0)
            accuracy[index] += 1 if value["correct"] else 0
            amount[index] += 1
    # plot the general accuracy
    plt.plot(accuracy)
    plt.show()


if __name__ == "__main__":
    type_predictor = "knn"
    if type_predictor == "knn":
        k = 1
        predictor_type = KNearestNeighbours(k)
    elif type_predictor == "svm":
        predictor_type = SVM()
    else:
        predictor_type = KNearestNeighbours(1)

    # embeddings_path = 'embeddings/basic-signs/embeddings.csv'
    embeddings_path = 'embeddings/fingerspelling/embeddings.csv'

    predictor = Predictor(embeddings_path, predictor_type)

    data_folder = '/home/tibe/Projects/design_project/sign-predictor/data/fingerspelling/data/'
    results = test_data(predictor, data_folder)
    # write the results to a json file
    with open("predictions/test_results/" + type_predictor + ".json", 'w') as f:
        json.dump(results, f)
    print(results)
    # plot_general_accuracy(results)
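The results file nests one list of window-level predictions per test video. A small sketch of turning that JSON into per-class accuracy (reading the knn.json this script writes):

```
import json

with open("predictions/test_results/knn.json") as f:
    results = json.load(f)

for label, videos in results.items():
    preds = [p for video in videos for p in video["predictions"]]
    if preds:
        correct = sum(p["correct"] for p in preds)
        print(f"{label}: {correct / len(preds):.2%} over {len(preds)} windows")
```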
@ -1,5 +1,5 @@
from argparse import ArgumentParser
from preprocessing.create_wlasl_landmarks_dataset import parse_create_args, create
from preprocessing.create_fingerspelling_dataset import parse_create_args, create
from preprocessing.extract_mediapipe_landmarks import parse_extract_args, extract
174 preprocessing/create_fingerspelling_dataset.py Normal file
@ -0,0 +1,174 @@
import os
import os.path as op
import json
import shutil

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from utils import get_logger
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from normalization.blazepose_mapping import map_blazepose_df

BASE_DATA_FOLDER = 'data/'

mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_holistic = mp.solutions.holistic
pose_landmarks = mp_holistic.PoseLandmark
hand_landmarks = mp_holistic.HandLandmark


def get_landmarks_names():
    '''
    Returns landmark names for the mediapipe holistic model
    '''
    pose_lmks = ','.join([f'{lmk.name.lower()}_x,{lmk.name.lower()}_y' for lmk in pose_landmarks])
    left_hand_lmks = ','.join([f'left_hand_{lmk.name.lower()}_x,left_hand_{lmk.name.lower()}_y'
                               for lmk in hand_landmarks])
    right_hand_lmks = ','.join([f'right_hand_{lmk.name.lower()}_x,right_hand_{lmk.name.lower()}_y'
                                for lmk in hand_landmarks])
    lmks_names = f'{pose_lmks},{left_hand_lmks},{right_hand_lmks}'
    return lmks_names


def convert_to_str(arr, precision=6):
    if isinstance(arr, np.ndarray):
        values = []
        for val in arr:
            if val == 0:
                values.append('0')
            else:
                values.append(f'{val:.{precision}f}')
        return f"[{','.join(values)}]"
    else:
        return str(arr)


def parse_create_args(parser):
    parser.add_argument('--landmarks-dataset', '-lmks', required=True,
                        help='Path to folder with landmarks npy files. '
                             'You need to run the `extract_mediapipe_landmarks.py` script first')
    parser.add_argument('--dataset-folder', '-df', default='data/wlasl',
                        help='Path to folder where the original `WLASL_v0.3.json` and `id_to_label.json` are stored. '
                             'Note that the final CSV files will be saved in this folder too.')
    parser.add_argument('--videos-folder', '-videos', default=None,
                        help='Path to folder with videos. If None, no video information (fps, length, '
                             'width and height) will be stored in the final csv file')
    parser.add_argument('--num-classes', '-nc', default=100, type=int, help='Number of classes to use in WLASL dataset')
    parser.add_argument('--create-new-split', action='store_true')
    parser.add_argument('--test-size', '-ts', default=0.25, type=float,
                        help='Test split percentage size. Only required if --create-new-split is set')


# python3 preprocessing.py --landmarks-dataset=data/landmarks -videos data/wlasl/videos
def create(args):
    logger = get_logger(__name__)

    landmarks_dataset = args.landmarks_dataset
    videos_folder = args.videos_folder
    dataset_folder = args.dataset_folder
    num_classes = args.num_classes
    test_size = args.test_size

    os.makedirs(dataset_folder, exist_ok=True)

    # shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/id_to_label.json'), dataset_folder)
    # shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/WLASL_v0.3.json'), dataset_folder)

    # get the files in the landmarks_dataset folder
    landmarks_files = os.listdir(landmarks_dataset)

    video_data = []
    for file in tqdm(landmarks_files):
        # file names look like `<label>!<split>.npy`, so split on `!`
        label = file.split('!')[0]
        subset = file.split('!')[1].split('.')[0]

        # strip the .npy extension to get the video id
        video_id = file.replace('.npy', "")

        video_dict = {'video_id': video_id,
                      'label_name': label,
                      'split': subset}

        if videos_folder is not None:
            cap = cv2.VideoCapture(op.join(videos_folder, f'{video_id}.mp4'))
            if not cap.isOpened():
                logger.warning(f'Video {video_id}.mp4 not found')
                continue
            width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            fps = cap.get(cv2.CAP_PROP_FPS)
            length = cap.get(cv2.CAP_PROP_FRAME_COUNT) / float(cap.get(cv2.CAP_PROP_FPS))
            video_info = {'video_width': width,
                          'video_height': height,
                          'fps': fps,
                          'length': length}
            video_dict.update(video_info)
        video_data.append(video_dict)

    df_video = pd.DataFrame(video_data)
    video_ids = df_video['video_id'].unique()
    lmks_data = []
    lmks_names = get_landmarks_names().split(',')

    # get labels from df_video
    labels = df_video['label_name'].unique()
    # map labels to ids
    label_to_id = {label: i for i, label in enumerate(labels)}

    # add a label id column to df_video
    df_video['labels'] = df_video['label_name'].map(label_to_id)

    # export the id-to-label mapping to a json file
    id_to_label = {i: label for label, i in label_to_id.items()}
    with open(op.join(dataset_folder, 'id_to_label.json'), 'w') as f:
        json.dump(id_to_label, f, indent=4)

    for video_id in video_ids:
        lmk_fn = op.join(landmarks_dataset, f'{video_id}.npy')
        if not op.exists(lmk_fn):
            logger.warning(f'{lmk_fn} file not found. Skipping')
            continue
        lmk = np.load(lmk_fn).T
        lmks_dict = {'video_id': video_id}
        for lmk_, name in zip(lmk, lmks_names):
            lmks_dict[name] = lmk_
        lmks_data.append(lmks_dict)

    df_lmks = pd.DataFrame(lmks_data)
    df = pd.merge(df_video, df_lmks)
    aux_columns = ['split', 'video_id', 'labels', 'label_name']
    if videos_folder is not None:
        aux_columns += ['video_width', 'video_height', 'fps', 'length']
    df_aux = df[aux_columns]
    df = map_blazepose_df(df)
    df = pd.concat([df, df_aux], axis=1)
    if args.create_new_split:
        df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)

        print(f'Num classes: {num_classes}')
        print(df_train['labels'].value_counts())
        print(df_test['labels'].value_counts())
        assert set(df_train['labels'].unique()) == set(df_test['labels'].unique()), \
            'The labels for the train and test dataframes are different. We recommend downloading the dataset ' \
            'again, or using the --create-new-split flag'
        for split, df_split in zip(['train', 'val'], [df_train, df_test]):
            fn_out = op.join(dataset_folder, f'{split}.csv')
            (df_split.reset_index(drop=True)
                     .applymap(convert_to_str)
                     .to_csv(fn_out, index=False))

    else:
        fn_out = op.join(dataset_folder, 'train.csv')
        (df.reset_index(drop=True)
           .applymap(convert_to_str)
           .to_csv(fn_out, index=False))
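For clarity on the CSV layout: every cell holds one landmark coordinate's whole trajectory across the video, serialised by `convert_to_str` into a bracketed string. A quick illustration (assuming the function above is in scope):

```
import numpy as np

print(convert_to_str(np.array([0.123456789, 0.0, 0.5])))
# -> [0.123457,0,0.500000]
print(convert_to_str('train'))  # non-arrays pass through unchanged -> train
```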
@ -4,6 +4,8 @@ import pandas as pd
from tqdm.auto import tqdm
import json

from normalization.blazepose_mapping import map_blazepose_df

def create(train_landmark_files, train_csv, dataset_folder, test_size):
    os.makedirs(dataset_folder, exist_ok=True)

@ -17,15 +19,15 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
    mapping = {
        'pose_0': 'nose',
        'pose_1': 'leftEye',
        'pose_2': 'rightEye',
        'pose_3': 'leftEar',
        'pose_4': 'rightEar',
        'pose_5': 'leftShoulder',
        'pose_6': 'rightShoulder',
        'pose_7': 'leftElbow',
        'pose_8': 'rightElbow',
        'pose_9': 'leftWrist',
        'pose_10': 'rightWrist',
        'pose_4': 'rightEye',
        'pose_7': 'leftEar',
        'pose_8': 'rightEar',
        'pose_11': 'leftShoulder',
        'pose_12': 'rightShoulder',
        'pose_13': 'leftElbow',
        'pose_14': 'rightElbow',
        'pose_15': 'leftWrist',
        'pose_16': 'rightWrist',

        'left_hand_0': 'wrist_left',
        'left_hand_1': 'thumbCMC_left',
@ -77,7 +79,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
        columns.append(f'{v}_X')
        columns.append(f'{v}_Y')

    for _, row in tqdm(train_df.head(6000).iterrows(), total=6000):
    for _, row in tqdm(train_df.head(10000).iterrows(), total=10000):
        path, participant_id, sequence_id, sign = row['path'], row['participant_id'], row['sequence_id'], row['sign']
        parquet_file = os.path.join(train_landmark_files, str(participant_id), f"{sequence_id}.parquet")

@ -136,6 +138,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
        video_data.append(new_landmark_data)

    video_data = pd.concat(video_data, axis=0, ignore_index=True)
    video_data = map_blazepose_df(video_data, rename=False)
    video_data.to_csv(os.path.join(dataset_folder, 'spoter.csv'), index=False)

train_landmark_files = 'data/train_landmark_files'
@ -110,6 +110,7 @@ def create(args):
|
||||
'length': length}
|
||||
video_dict.update(video_info)
|
||||
video_data.append(video_dict)
|
||||
|
||||
df_video = pd.DataFrame(video_data)
|
||||
video_ids = df_video['video_id'].unique()
|
||||
lmks_data = []
|
||||
@ -126,9 +127,7 @@ def create(args):
|
||||
lmks_data.append(lmks_dict)
|
||||
|
||||
df_lmks = pd.DataFrame(lmks_data)
|
||||
print(df_lmks)
|
||||
df = pd.merge(df_video, df_lmks)
|
||||
print(df)
|
||||
aux_columns = ['split', 'video_id', 'labels', 'label_name']
|
||||
if videos_folder is not None:
|
||||
aux_columns += ['video_width', 'video_height', 'fps', 'length']
|
||||
|
||||
@ -132,6 +132,12 @@ def extract(args):
    ret, image_orig = cap.read()
    height, width = image_orig.shape[:2]
    landmarks_video = []

    # downsample towards 20 fps by computing how many frames to skip after each processed one
    frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
    frame_skip = (frame_rate // 20) - 1

    with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:
        with mp_holistic.Holistic(
                static_image_mode=False,
@ -145,6 +151,9 @@ def extract(args):
                print(e)
            landmarks = get_landmarks(image_orig, holistic, debug=True)
            ret, image_orig = cap.read()
            for _ in range(frame_skip):
                ret, image_orig = cap.read()
                pbar.update(1)
            landmarks_video.append(landmarks)
            pbar.update(1)
    landmarks_video = np.vstack(landmarks_video)
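The skip count drops `frame_skip` frames after each processed one, so the effective rate is `fps / (frame_skip + 1)`; note this only lands exactly on 20 fps when the source rate is a multiple of 20. A quick check of the formula:

```
for fps in (20, 25, 30, 50, 60):
    frame_skip = (fps // 20) - 1
    print(fps, "->", fps / (frame_skip + 1))
# 20 -> 20.0, 25 -> 25.0, 30 -> 30.0, 50 -> 25.0, 60 -> 20.0
```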
@ -8,7 +8,6 @@ dataset = "data/processed/spoter.csv"

# read the dataset
df = pd.read_csv(dataset)
df = map_blazepose_df(df)

with open("data/sign_to_prediction_index_map.json", "r") as f:
    sign_to_prediction_index_max = json.load(f)
@ -1,7 +1,6 @@
pandas
bokeh==2.4.3
boto3>=1.9
clearml==1.6.4
ipywidgets==8.0.4
matplotlib==3.5.3
mediapipe==0.8.11
@ -9,6 +8,8 @@ notebook==6.5.2
opencv-python==4.6.0.66
plotly==5.11.0
scikit-learn==1.0.2
torch
torchvision
clearml==1.10.3
torch==2.0.0
torchvision==0.15.1
tqdm==4.54.1
optuna==3.1.1
19 train.py
@ -15,7 +15,7 @@ from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
import copy

import numpy as np
from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
    train_epoch_embedding_online, evaluate_embedding
@ -32,7 +32,7 @@ except ImportError:
    pass

PROJECT_NAME = "spoter"
PROJECT_NAME = "SpoterEmbedding"
CLEARML = "clearml"

@ -75,12 +75,25 @@ def train(args, tracker: Tracker):

    # Construct the model
    if not args.classification_model:
        # when fine-tuning, rebuild the model from the construction arguments stored in the checkpoint
        if args.finetune:
            checkpoint = torch.load(args.checkpoint_path, map_location=device)

            slrt_model = SPOTER_EMBEDDINGS(
                features=checkpoint["config_args"].vector_length,
                hidden_dim=checkpoint["config_args"].hidden_dim,
                norm_emb=checkpoint["config_args"].normalize_embeddings,
                dropout=checkpoint["config_args"].dropout,
            )
        else:
            slrt_model = SPOTER_EMBEDDINGS(
                features=args.vector_length,
                hidden_dim=args.hidden_dim,
                norm_emb=args.normalize_embeddings,
                dropout=args.dropout
                dropout=args.dropout,
            )

    model_type = 'embed'
    if args.hard_triplet_mining == "None":
        cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
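The finetune branch assumes every checkpoint stores its construction arguments alongside the weights, matching how webcam.py reads `config_args` and `state_dict` back. The saving side isn't shown in this diff; a sketch of the assumed layout:

```
# Assumed checkpoint layout, inferred from the loading code above and in webcam.py.
torch.save({
    "state_dict": slrt_model.state_dict(),
    "config_args": args,  # argparse.Namespace with vector_length, hidden_dim, ...
}, "checkpoints/checkpoint_embed_3006.pth")
```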
31 train.sh
@ -1,22 +1,23 @@
#!/bin/sh
python -m train \
--save_checkpoints_every 10 \
--experiment_name "augment_rotate_75_x8" \
--epochs 300 \
python3 -m train \
--save_checkpoints_every 1 \
--experiment_name "Finetune Basic Signs" \
--epochs 100 \
--optimizer "ADAM" \
--lr 0.001 \
--lr 0.00001 \
--batch_size 16 \
--dataset_name "processed" \
--training_set_path "spoter_train.csv" \
--validation_set_path "spoter_test.csv" \
--dataset_name "BasicSigns" \
--training_set_path "train.csv" \
--validation_set_path "val.csv" \
--vector_length 32 \
--epoch_iters -1 \
--scheduler_factor 0 \
--hard_triplet_mining "in_batch" \
--scheduler_factor 0.05 \
--hard_triplet_mining "None" \
--filter_easy_triplets \
--triplet_loss_margin 1 \
--triplet_loss_margin 2 \
--dropout 0.2 \
--augmentations_prob=0.75 \
--hard_mining_scheduler_triplets_threshold=0 \
--normalize_embeddings \
--num_classes 100 \
--tracker=clearml \
--dataset_loader=clearml \
--dataset_project="SpoterEmbedding" \
--finetune \
--checkpoint_path "checkpoints/checkpoint_embed_3006.pth"
@ -81,4 +81,7 @@ def get_default_args():
                        help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the "
                             "distance threshold of the batch sorter")

    parser.add_argument("--finetune", action='store_true', default=False, help="Finetune the model")
    parser.add_argument("--checkpoint_path", type=str, default="")

    return parser
1636 visualize_data.ipynb Normal file
File diff suppressed because one or more lines are too long
334 webcam.py
@ -1,305 +1,39 @@
from collections import Counter

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
import torch

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
from models import SPOTER_EMBEDDINGS
from predictions.k_nearest import KNearestNeighbours
from predictions.predictor import Predictor
from predictions.svm_model import SVM

# Initialize the MediaPipe holistic model
holistic = mp.solutions.holistic.Holistic(
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
    model_complexity=2
)
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils
if __name__ == '__main__':
    buffer = []
    # open the webcam stream
    cap = cv2.VideoCapture(0)

BODY_IDENTIFIERS = [
    0, 33, 5, 2, 8, 7, 12, 11, 14, 13, 16, 15,
]

HAND_IDENTIFIERS = [
    0, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13, 20, 19, 18, 17, 4, 3, 2, 1,
]

def extract_keypoints(image_orig):
    image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
    results = holistic.process(image)

    def extract_keypoints(lmks):
        if lmks:
            a = np.array([[float(lmk.x), float(lmk.y)] for lmk in lmks.landmark])
            return a
        return None

    def calculate_neck(keypoints):
        left_shoulder = keypoints[11]
        right_shoulder = keypoints[12]

        neck = [(float(left_shoulder[0]) + float(right_shoulder[0])) / 2,
                (float(left_shoulder[1]) + float(right_shoulder[1])) / 2]
        # add the neck to the keypoints
        keypoints = np.append(keypoints, [neck], axis=0)
        return keypoints

    pose = extract_keypoints(results.pose_landmarks)
    pose = calculate_neck(pose)
    pose_norm = normalize_pose(pose)
    # filter out keypoints that are not in BODY_IDENTIFIERS and make sure they are in the correct order
    pose_norm = pose_norm[BODY_IDENTIFIERS]

    left_hand = extract_keypoints(results.left_hand_landmarks)
    right_hand = extract_keypoints(results.right_hand_landmarks)

    if left_hand is None and right_hand is None:
        return None

    # normalize hands
    if left_hand is not None:
        left_hand = normalize_hand(left_hand)
    type_predictor = "svm"
    if type_predictor == "knn":
        k = 10
        predictor_type = KNearestNeighbours(k)
    elif type_predictor == "svm":
        predictor_type = SVM()
    else:
        left_hand = np.zeros((21, 2))
    if right_hand is not None:
        right_hand = normalize_hand(right_hand)
    else:
        right_hand = np.zeros((21, 2))

    left_hand = left_hand[HAND_IDENTIFIERS]
    right_hand = right_hand[HAND_IDENTIFIERS]

    # combine pose and hands
    pose_norm = np.append(pose_norm, left_hand, axis=0)
    pose_norm = np.append(pose_norm, right_hand, axis=0)

    # shift the values from [0, 1] to [-0.5, 0.5]
    pose_norm -= 0.5

    return pose_norm
    predictor_type = KNearestNeighbours(1)
buffer = []

left_shoulder_index = 11
right_shoulder_index = 12
neck_index = 33
nose_index = 0
left_eye_index = 2
# embeddings_path = 'embeddings/basic-signs/embeddings.csv'
embeddings_path = 'embeddings/fingerspelling/embeddings.csv'

# normalize a single body; keypoints is a numpy array of shape (identifiers, 2)
def normalize_pose(keypoints):
    left_shoulder = keypoints[left_shoulder_index]
    right_shoulder = keypoints[right_shoulder_index]
    predictor = Predictor(embeddings_path, predictor_type)

    neck = keypoints[neck_index]
    nose = keypoints[nose_index]
    index = 0

    # do not even start the analysis if some necessary elements are not present
    if (left_shoulder[0] == 0 or right_shoulder[0] == 0
            or (left_shoulder[0] == right_shoulder[0] and left_shoulder[1] == right_shoulder[1])) and (
            neck[0] == 0 or nose[0] == 0 or (neck[0] == nose[0] and neck[1] == nose[1])):
        return keypoints
    while cap.isOpened():
        # Wait for key press to exit
        if cv2.waitKey(5) & 0xFF == 27:
            break

    if left_shoulder[0] != 0 and right_shoulder[0] != 0 and (left_shoulder[0] != right_shoulder[0] or left_shoulder[1] != right_shoulder[1]):
        shoulder_distance = ((((left_shoulder[0] - right_shoulder[0]) ** 2) + ((left_shoulder[1] - right_shoulder[1]) ** 2)) ** 0.5)
        head_metric = shoulder_distance
    else:
        neck_nose_distance = ((((neck[0] - nose[0]) ** 2) + ((neck[1] - nose[1]) ** 2)) ** 0.5)
        head_metric = neck_nose_distance

    # Set the starting and ending point of the normalization bounding box
    starting_point = [keypoints[neck_index][0] - 3 * head_metric, keypoints[left_eye_index][1] + head_metric]
    ending_point = [keypoints[neck_index][0] + 3 * head_metric, starting_point[1] - 6 * head_metric]

    if starting_point[0] < 0:
        starting_point[0] = 0
    if starting_point[1] < 0:
        starting_point[1] = 0
    if ending_point[0] < 0:
        ending_point[0] = 0
    if ending_point[1] < 0:
        ending_point[1] = 0

    # Normalize the keypoints
    for i in range(len(keypoints)):
        keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
        keypoints[i][1] = (keypoints[i][1] - ending_point[1]) / (starting_point[1] - ending_point[1])

    return keypoints
def normalize_hand(keypoints):
    x_values = [keypoints[i][0] for i in range(len(keypoints)) if keypoints[i][0] != 0]
    y_values = [keypoints[i][1] for i in range(len(keypoints)) if keypoints[i][1] != 0]

    if not x_values or not y_values:
        return keypoints

    width, height = max(x_values) - min(x_values), max(y_values) - min(y_values)
    if width > height:
        delta_x = 0.1 * width
        delta_y = delta_x + ((width - height) / 2)
    else:
        delta_y = 0.1 * height
        delta_x = delta_y + ((height - width) / 2)

    starting_point = (min(x_values) - delta_x, min(y_values) - delta_y)
    ending_point = (max(x_values) + delta_x, max(y_values) + delta_y)

    if ending_point[0] - starting_point[0] == 0 or ending_point[1] - starting_point[1] == 0:
        return keypoints

    # normalize keypoints
    for i in range(len(keypoints)):
        keypoints[i][0] = (keypoints[i][0] - starting_point[0]) / (ending_point[0] - starting_point[0])
        keypoints[i][1] = (keypoints[i][1] - starting_point[1]) / (ending_point[1] - starting_point[1])

    return keypoints
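The delta terms pad the tight bounding box by 10 % of its longer side and stretch the shorter side until the box is square, so the hand keeps its aspect ratio after normalization. A worked example for a box of width 0.4 and height 0.2:

```
width, height = 0.4, 0.2
delta_x = 0.1 * width                     # 0.04
delta_y = delta_x + (width - height) / 2  # 0.14
print(width + 2 * delta_x, height + 2 * delta_y)  # 0.48 0.48 -> square box
```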
# load the training embedding csv
df = pd.read_csv('embeddings/basic-signs/embeddings.csv')


def minkowski_distance_p(x, y, p=2):
    x = np.asarray(x)
    y = np.asarray(y)

    # Find smallest common datatype with float64 (return type of this
    # function) - addresses #10262.
    # Don't just cast to float64 for complex input case.
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype), 'float64')

    # Make sure x and y are NumPy arrays of correct datatype.
    x = x.astype(common_datatype)
    y = y.astype(common_datatype)

    if p == np.inf:
        return np.amax(np.abs(y - x), axis=-1)
    elif p == 1:
        return np.sum(np.abs(y - x), axis=-1)
    else:
        return np.sum(np.abs(y - x) ** p, axis=-1)


def minkowski_distance(x, y, p=2):
    x = np.asarray(x)
    y = np.asarray(y)
    if p == np.inf or p == 1:
        return minkowski_distance_p(x, y, p)
    else:
        return minkowski_distance_p(x, y, p) ** (1. / p)


def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
    x = np.array(keypoints)
    m, k = x.shape
    y = np.asarray(embeddings)
    n, kk = y.shape

    if k != kk:
        raise ValueError(f"x contains {k}-dimensional vectors but y contains "
                         f"{kk}-dimensional vectors")

    if m * n * k <= threshold:
        return minkowski_distance(x[:, np.newaxis, :], y[np.newaxis, :, :], p)
    else:
        result = np.empty((m, n), dtype=float)  # FIXME: figure out the best dtype
        if m < n:
            for i in range(m):
                result[i, :] = minkowski_distance(x[i], y, p)
        else:
            for j in range(n):
                result[:, j] = minkowski_distance(x, y[j], p)
        return result
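These helpers appear adapted from scipy.spatial's Minkowski utilities. A toy call, assuming one query embedding against a stack of stored embeddings:

```
import numpy as np

query = np.random.rand(1, 32)     # one new embedding (as a 2-D row)
stored = np.random.rand(100, 32)  # training embeddings
dists = distance_matrix(query, stored, p=2)
print(dists.shape)                # (1, 100); np.argsort(dists)[0][:5] -> 5 nearest indices
```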
|
||||
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
|
||||
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
||||
|
||||
model = SPOTER_EMBEDDINGS(
|
||||
features=checkpoint["config_args"].vector_length,
|
||||
hidden_dim=checkpoint["config_args"].hidden_dim,
|
||||
norm_emb=checkpoint["config_args"].normalize_embeddings,
|
||||
).to(device)
|
||||
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
embeddings = df.drop(columns=['labels', 'label_name', 'embeddings'])
|
||||
|
||||
# convert embedding from string to list of floats
|
||||
embeddings["embeddings"] = embeddings["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
|
||||
# drop embeddings2
|
||||
embeddings = embeddings.drop(columns=['embeddings2'])
|
||||
# to list
|
||||
embeddings = embeddings["embeddings"].tolist()
|
||||
|
||||
def make_prediction(keypoints):
|
||||
# run model on frame
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
keypoints = torch.from_numpy(np.array([keypoints])).float().to(device)
|
||||
new_embeddings = model(keypoints).cpu().numpy().tolist()[0]
|
||||
|
||||
# calculate distance matrix
|
||||
dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
|
||||
|
||||
# get the 5 closest matches and select the class that is most common and use the average distance as the score
|
||||
# get the 5 closest matches
|
||||
indeces = np.argsort(dist_matrix)[0][:5]
|
||||
# get the labels
|
||||
labels = df["label_name"].iloc[indeces].tolist()
|
||||
c = Counter(labels).most_common()[0][0]
|
||||
|
||||
# filter indeces to only include the most common label
|
||||
indeces = [i for i in indeces if df["label_name"].iloc[i] == c]
|
||||
# get the average distance
|
||||
score = np.mean(dist_matrix[0][indeces])
|
||||
|
||||
return c, score
|
||||
|
||||
# open webcam stream
|
||||
cap = cv2.VideoCapture(0)
|
||||
|
||||
while cap.isOpened():
|
||||
# read frame
|
||||
ret, frame = cap.read()
|
||||
pose = extract_keypoints(frame)
|
||||
pose = predictor.extract_keypoints(frame)
|
||||
|
||||
if pose is None:
|
||||
cv2.imshow('MediaPipe Hands', frame)
|
||||
@ -310,29 +44,11 @@ while cap.isOpened():
        buffer.pop(0)

    if len(buffer) == 15:
        label, score = make_prediction(buffer)
        label, score = predictor.make_prediction(buffer)

        # draw the label
        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, str(label), (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)

    # Show the frame
    cv2.imshow('MediaPipe Hands', frame)

    # Wait for key press to exit
    if cv2.waitKey(5) & 0xFF == 27:
        break

# open a test video instead of the webcam
# cap = cv2.VideoCapture('E.mp4')
# while cap.isOpened():
#     # read a frame
#     ret, frame = cap.read()
#     if frame is None:
#         break
#     pose = extract_keypoints(frame)

#     buffer.append(pose)

#     label, score = make_prediction(buffer)
#     print(label, score)