Merge branch 'FingerspellingEmbedding-+-ClearML' into 'main'
Fingerspelling embedding + ClearML
See merge request wesign/spoterembedding!1
This commit was merged in pull request #1.
1 .gitignore vendored
@@ -155,3 +155,4 @@ out-img/
 converted_models/
 *.pth
 *.onnx
+.devcontainer
20 README2.md Normal file
@@ -0,0 +1,20 @@
# Spoter Embeddings

## Creating dataset

First, make a folder where all your videos are located. When this is done, all keypoints can be extracted from the videos using the following command. This will extract the keypoints and store them in \<path-to-landmarks-folder\>.

```
python3 preprocessing.py extract --videos-folder <path-to-videos-folder> --output-landmark <path-to-landmarks-folder>
```

When this is done, the dataset can be created using the following command:

```
python3 preprocessing.py create --landmarks-dataset <path-to-landmarks-folder> --videos-folder <path-to-videos-folder> --dataset-folder <dataset-output-folder> (--create-new-split --test-size <test-percentage>)
```

The above command generates a train (and val) CSV file which includes all the extracted keypoints. These can then be used to train a model or to generate embeddings.

## Creating Embeddings

The embeddings can be created using the following command:

```
python3 export_embeddings.py --checkpoint <path-to-checkpoints-file> --dataset <path-to-dataset-file> --output <embeddings-output-file>
```

The command above generates the embeddings for a given dataset and saves them as a CSV file.
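As an illustrative sketch (not part of the committed README): the exported CSV can be consumed directly with pandas for a nearest-neighbour lookup. The file name `embeddings.csv` below is an assumption, and the snippet assumes the `embeddings` column parses as a flat list of floats, as written by `export_embeddings.py`.

```
# Sketch: load an embeddings CSV written by export_embeddings.py and run a
# simple nearest-neighbour lookup. The file name "embeddings.csv" and the
# column layout (embeddings, label_name, labels) are assumptions.
import ast

import numpy as np
import pandas as pd

df = pd.read_csv("embeddings.csv")

# The embeddings column is stored as the string form of a Python list.
vectors = np.array([ast.literal_eval(e) for e in df["embeddings"]])
labels = df["label_name"].to_numpy()


def nearest_label(query_vec):
    """Return the label of the stored embedding closest to query_vec (L2 distance)."""
    distances = np.linalg.norm(vectors - query_vec, axis=1)
    return labels[int(np.argmin(distances))]


# Example: the nearest neighbour of a stored embedding should be itself.
print(nearest_label(vectors[0]))
```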
BIN checkpoints/checkpoint_embed_3006.pth Normal file
Binary file not shown.
BIN checkpoints/checkpoint_embed_3835.pth Normal file
Binary file not shown.
@@ -17,5 +17,4 @@ requests==2.28.1
 onnx==1.12.0
 onnx-tf==1.10.0
 onnxruntime==1.12.1
-tensorflow
-tensorflow-probability
+coremltools==6.3.0
91 embeddings/basic-signs/embeddings-online-dict.csv Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
embeddings,label_name,labels,embeddings2
|
||||||
|
"[[ 0.01295076 -0.1074132 0.5111911 0.08452003 -0.7638924 0.6458221
|
||||||
|
-1.3892679 -1.1791427 -0.30736607 -0.41543546 -0.6358013 0.31411174
|
||||||
|
0.878936 1.7265923 -0.09298562 0.12005516 0.26995468 -1.5934305
|
||||||
|
-0.12619524 0.9111336 0.91827893 0.5948979 0.90334046 -0.84089845
|
||||||
|
-0.29575598 -0.3024254 -0.03074584 1.4402957 -0.22309914 0.54750854
|
||||||
|
-0.68767273 -0.0665718 ]]",BEDANKEN,0,"[0.012950755655765533, -0.10741320252418518, 0.5111911296844482, 0.08452003449201584, -0.763892412185669, 0.6458221077919006, -1.389267921447754, -1.179142713546753, -0.3073660731315613, -0.41543546319007874, -0.6358013153076172, 0.31411173939704895, 0.8789359927177429, 1.7265923023223877, -0.0929856151342392, 0.12005516141653061, 0.2699546813964844, -1.593430519104004, -0.12619523704051971, 0.9111335873603821, 0.9182789325714111, 0.5948979258537292, 0.9033404588699341, -0.8408984541893005, -0.2957559823989868, -0.30242541432380676, -0.030745841562747955, 1.440295696258545, -0.22309914231300354, 0.5475085377693176, -0.6876727342605591, -0.06657180190086365]"
|
||||||
|
"[[ 1.0287632 -0.94657993 0.7852507 -0.37169546 0.3614158 0.5360078
|
||||||
|
-0.9989285 -0.24402924 -0.47437504 -0.17307773 0.9196662 0.5362411
|
||||||
|
-0.50366765 0.7600024 0.6701926 -0.5288625 -0.38939306 -1.0722619
|
||||||
|
-1.33836 1.331829 -0.6958481 0.57767504 0.87002045 -0.59203666
|
||||||
|
-0.57576096 0.70008135 0.8224436 -0.32370126 0.5276206 -0.62150276
|
||||||
|
-0.9639951 -0.40879178]]",GOED,1,"[1.0287631750106812, -0.9465799331665039, 0.785250723361969, -0.37169545888900757, 0.3614158034324646, 0.536007821559906, -0.9989284873008728, -0.244029238820076, -0.47437503933906555, -0.17307773232460022, 0.9196661710739136, 0.5362411141395569, -0.5036676526069641, 0.7600023746490479, 0.6701925992965698, -0.528862476348877, -0.38939306139945984, -1.072261929512024, -1.3383599519729614, 1.3318289518356323, -0.6958481073379517, 0.5776750445365906, 0.8700204491615295, -0.5920366644859314, -0.5757609605789185, 0.7000813484191895, 0.8224436044692993, -0.32370126247406006, 0.5276206135749817, -0.6215027570724487, -0.963995099067688, -0.40879178047180176]"
|
||||||
|
"[[-0.5988377 -1.6454532 0.5696761 0.27801913 -0.39687645 1.5192578
|
||||||
|
-0.57908726 -1.9851501 1.1918951 -1.3280408 0.27390718 0.71642965
|
||||||
|
-1.0471189 0.9825219 -0.6625729 -0.28641602 2.0109527 -0.97667754
|
||||||
|
-2.3978348 2.666861 0.15066166 -1.6922158 0.9681514 -1.1413386
|
||||||
|
0.08027886 -0.6215762 -0.09759905 0.6157277 -0.01341684 0.9806912
|
||||||
|
-0.09793527 -1.9003383 ]]",GOEDEMIDDAG,2,"[-0.598837673664093, -1.6454532146453857, 0.5696761012077332, 0.27801913022994995, -0.3968764543533325, 1.5192577838897705, -0.5790872573852539, -1.9851500988006592, 1.1918951272964478, -1.3280408382415771, 0.2739071846008301, 0.7164296507835388, -1.047118902206421, 0.9825218915939331, -0.6625729203224182, -0.28641602396965027, 2.0109527111053467, -0.9766775369644165, -2.3978347778320312, 2.666861057281494, 0.1506616622209549, -1.6922158002853394, 0.9681513905525208, -1.141338586807251, 0.08027885854244232, -0.621576189994812, -0.09759905189275742, 0.6157277226448059, -0.013416841626167297, 0.9806911945343018, -0.0979352742433548, -1.9003382921218872]"
|
||||||
|
"[[-0.11578767 0.28371835 0.8604608 -0.31061298 -0.78153455 -0.24914108
|
||||||
|
-2.443886 -0.3267159 0.29237527 -0.7879326 -0.12082034 0.32809633
|
||||||
|
-1.0938975 0.7125962 0.49117383 -0.3119098 0.15369919 0.39569452
|
||||||
|
-1.18983 1.2498406 -0.38752007 -0.54896545 -0.620718 0.02654967
|
||||||
|
-0.10994279 -0.39713857 -0.18035033 -1.554728 0.12178666 -0.5698373
|
||||||
|
1.1097438 -1.2350351 ]]",GOEDEMORGEN,3,"[-0.11578767001628876, 0.2837183475494385, 0.8604608178138733, -0.3106129765510559, -0.7815345525741577, -0.24914108216762543, -2.4438860416412354, -0.326715886592865, 0.29237526655197144, -0.7879325747489929, -0.12082034349441528, 0.32809633016586304, -1.0938974618911743, 0.7125961780548096, 0.4911738336086273, -0.3119097948074341, 0.15369918942451477, 0.3956945240497589, -1.18982994556427, 1.2498406171798706, -0.38752007484436035, -0.5489654541015625, -0.6207180023193359, 0.02654966711997986, -0.10994279384613037, -0.3971385657787323, -0.18035033345222473, -1.5547280311584473, 0.12178666144609451, -0.5698372721672058, 1.1097438335418701, -1.2350350618362427]"
|
||||||
|
"[[-0.9557147 0.03368077 1.0923759 -0.41077584 -1.6992741 -0.59566665
|
||||||
|
-1.1562454 0.8824008 -0.25122 0.625515 1.6720039 0.5639413
|
||||||
|
-0.35894975 -0.05896414 0.84854823 0.4072139 -1.1763096 -1.5435666
|
||||||
|
-1.5066593 0.66299886 -1.3121625 1.4410669 -0.6341419 -0.4873513
|
||||||
|
0.8113033 -0.01502099 0.32262453 2.5855515 -0.7294207 0.7582124
|
||||||
|
-0.1386989 -0.17011487]]",GOEDENACHT,4,"[-0.9557147026062012, 0.03368076682090759, 1.0923758745193481, -0.4107758402824402, -1.6992740631103516, -0.5956666469573975, -1.1562453508377075, 0.8824008107185364, -0.2512199878692627, 0.6255149841308594, 1.6720038652420044, 0.5639412999153137, -0.35894975066185, -0.058964136987924576, 0.8485482335090637, 0.40721389651298523, -1.176309585571289, -1.5435665845870972, -1.5066592693328857, 0.6629988551139832, -1.3121625185012817, 1.441066861152649, -0.6341419219970703, -0.48735129833221436, 0.8113033175468445, -0.015020990744233131, 0.322624534368515, 2.5855515003204346, -0.7294207215309143, 0.7582123875617981, -0.13869890570640564, -0.17011487483978271]"
|
||||||
|
"[[ 0.13687058 -1.099074 1.1867181 -1.2216619 -0.965039 -0.6957289
|
||||||
|
-1.6400954 0.7317837 0.5887569 0.18756844 1.4867041 0.63357025
|
||||||
|
-1.5853169 0.4157976 0.76919526 0.08512082 -1.0937556 -0.07339763
|
||||||
|
-2.8580015 1.6835626 -1.4538742 1.2302468 -0.49323368 -0.81243044
|
||||||
|
-0.11378415 -0.09592562 0.31165385 -0.32617894 0.02981896 -0.01485211
|
||||||
|
1.1880591 -0.40726334]]",GOEDENAVOND,5,"[0.1368705779314041, -1.0990740060806274, 1.1867181062698364, -1.221661925315857, -0.9650390148162842, -0.6957288980484009, -1.6400953531265259, 0.7317836880683899, 0.5887569189071655, 0.18756844103336334, 1.4867041110992432, 0.6335702538490295, -1.5853168964385986, 0.4157975912094116, 0.7691952586174011, 0.08512081950902939, -1.0937556028366089, -0.07339762896299362, -2.858001470565796, 1.6835626363754272, -1.4538742303848267, 1.2302467823028564, -0.49323368072509766, -0.8124304413795471, -0.11378414928913116, -0.09592562168836594, 0.31165385246276855, -0.3261789381504059, 0.029818959534168243, -0.014852114021778107, 1.1880590915679932, -0.4072633385658264]"
|
||||||
|
"[[ 1.3673804 1.7083405 0.58775777 0.26056585 -0.29101595 -0.69954723
|
||||||
|
0.45681733 1.6908163 0.02684825 0.7537957 2.2054908 0.28831166
|
||||||
|
-1.5712252 -1.6869793 0.8009781 -0.51264 -1.3544708 -0.72502786
|
||||||
|
-0.31009498 -0.23777717 -1.7524906 1.6333568 1.5263942 0.22317217
|
||||||
|
-0.40576094 2.2136266 1.4030526 -1.0599502 0.8686069 -1.1141618
|
||||||
|
-0.01899157 -1.2256656 ]]",JA,6,"[1.3673803806304932, 1.7083405256271362, 0.5877577662467957, 0.260565847158432, -0.29101595282554626, -0.6995472311973572, 0.4568173289299011, 1.6908162832260132, 0.02684824913740158, 0.7537956833839417, 2.205490827560425, 0.2883116602897644, -1.5712251663208008, -1.6869792938232422, 0.8009781241416931, -0.5126399993896484, -1.3544708490371704, -0.725027859210968, -0.3100949823856354, -0.23777717351913452, -1.7524906396865845, 1.6333568096160889, 1.526394248008728, 0.2231721729040146, -0.4057609438896179, 2.2136266231536865, 1.403052568435669, -1.0599502325057983, 0.8686069250106812, -1.1141618490219116, -0.01899157091975212, -1.22566556930542]"
|
||||||
|
"[[ 1.8208723 -0.59625745 0.15060106 -0.23053254 1.0344827 1.7016335
|
||||||
|
0.1968202 -0.48188782 0.3004466 -0.53336793 0.89704275 0.26869062
|
||||||
|
-1.0112952 0.0343281 0.32149622 -0.8189871 0.33571526 -0.57059044
|
||||||
|
-0.7577773 1.4743904 -0.01735806 -0.4116615 2.3637483 -0.6602343
|
||||||
|
-0.6101709 1.8139238 1.041902 -0.9006022 0.6082014 -0.23616463
|
||||||
|
-0.62204087 -0.524899 ]]",LINKS,7,"[1.8208723068237305, -0.5962574481964111, 0.15060105919837952, -0.23053254187107086, 1.034482717514038, 1.7016334533691406, 0.1968201994895935, -0.4818878173828125, 0.30044659972190857, -0.533367931842804, 0.8970427513122559, 0.2686906158924103, -1.011295199394226, 0.034328099340200424, 0.32149621844291687, -0.8189870715141296, 0.33571526408195496, -0.5705904364585876, -0.7577772736549377, 1.4743903875350952, -0.01735806092619896, -0.4116615056991577, 2.36374831199646, -0.660234272480011, -0.6101709008216858, 1.8139238357543945, 1.04190194606781, -0.9006022214889526, 0.6082013845443726, -0.2361646294593811, -0.622040867805481, -0.5248990058898926]"
|
||||||
|
"[[ 1.3968729 2.340378 0.567814 0.5684975 -0.8795973 -1.0090083
|
||||||
|
-0.01301649 1.9747832 -0.3257468 0.76960254 1.402801 0.39424965
|
||||||
|
0.06948688 -0.51904166 1.0961802 -0.34335235 -1.8730735 -1.0801809
|
||||||
|
0.2751789 -0.96998405 -1.9822828 2.1046805 1.3094746 0.5770473
|
||||||
|
-0.19693638 1.8973799 0.99536693 0.2668684 0.28499565 -0.9838486
|
||||||
|
-0.19578832 -0.46982569]]",NEE,8,"[1.396872878074646, 2.3403780460357666, 0.5678139925003052, 0.5684974789619446, -0.8795973062515259, -1.0090082883834839, -0.013016488403081894, 1.974783182144165, -0.3257468044757843, 0.7696025371551514, 1.4028010368347168, 0.39424964785575867, 0.06948687881231308, -0.5190416574478149, 1.0961802005767822, -0.343352347612381, -1.8730734586715698, -1.0801808834075928, 0.2751789093017578, -0.9699840545654297, -1.9822827577590942, 2.1046805381774902, 1.3094745874404907, 0.5770472884178162, -0.19693638384342194, 1.8973798751831055, 0.9953669309616089, 0.2668684124946594, 0.2849956452846527, -0.9838485717773438, -0.19578832387924194, -0.4698256850242615]"
|
||||||
|
"[[ 1.8143647 -0.9315294 0.28706568 -1.0938002 -0.2451038 0.42331484
|
||||||
|
-0.31653756 0.00268575 0.89822304 1.0639151 1.7406261 1.1707302
|
||||||
|
-1.6639041 0.30558044 0.22984704 -1.0260515 -0.08818819 -0.14696571
|
||||||
|
-1.5269523 0.55015475 -0.12936743 0.93951946 1.3917109 -0.62517196
|
||||||
|
-1.3869016 1.4833041 1.4647726 -1.0285482 -0.3704591 1.4672718
|
||||||
|
0.40315557 0.04754414]]",RECHTS,9,"[1.8143646717071533, -0.9315294027328491, 0.28706568479537964, -1.0938001871109009, -0.24510380625724792, 0.4233148396015167, -0.3165375590324402, 0.002685748040676117, 0.8982230424880981, 1.0639151334762573, 1.7406260967254639, 1.1707302331924438, -1.663904070854187, 0.3055804371833801, 0.2298470437526703, -1.0260515213012695, -0.08818819373846054, -0.14696571230888367, -1.5269522666931152, 0.5501547455787659, -0.1293674260377884, 0.939519464969635, 1.391710877418518, -0.625171959400177, -1.386901617050171, 1.4833041429519653, 1.4647725820541382, -1.028548240661621, -0.37045910954475403, 1.4672718048095703, 0.4031555652618408, 0.04754413664340973]"
|
||||||
|
"[[ 1.681777 -0.79441184 0.07110912 -0.01304399 1.110081 2.1115103
|
||||||
|
-0.27283993 -1.5035654 0.55927217 -0.9903696 0.05718866 0.19891238
|
||||||
|
-0.7289791 0.71738267 -0.41428873 -0.58389515 1.1087953 -0.6616588
|
||||||
|
-0.6330258 1.5830203 0.60013187 -1.0640136 2.3917787 -0.70012283
|
||||||
|
-0.7359694 0.9847692 0.94458187 -0.5915589 0.3662199 0.39359796
|
||||||
|
-0.6499281 -0.30565363]]",SALUUT,10,"[1.681777000427246, -0.794411838054657, 0.07110912352800369, -0.013043994084000587, 1.1100809574127197, 2.1115102767944336, -0.2728399336338043, -1.5035654306411743, 0.5592721700668335, -0.9903696179389954, 0.05718865618109703, 0.1989123821258545, -0.7289791107177734, 0.7173826694488525, -0.414288729429245, -0.5838951468467712, 1.1087952852249146, -0.6616588234901428, -0.6330258250236511, 1.5830203294754028, 0.6001318693161011, -1.0640136003494263, 2.3917787075042725, -0.7001228332519531, -0.7359694242477417, 0.9847692251205444, 0.9445818662643433, -0.5915588736534119, 0.3662199079990387, 0.39359796047210693, -0.649928092956543, -0.3056536316871643]"
|
||||||
|
"[[ 0.56773967 -0.6260798 0.23771092 -0.10016712 0.86817515 0.92371875
|
||||||
|
-0.16313902 -0.3785063 -0.41955388 -0.586288 -0.33712262 -0.07999519
|
||||||
|
-0.98128295 0.18787864 -0.36432508 -0.06605241 0.00710656 -0.36359936
|
||||||
|
-0.00642961 1.257384 0.64647824 -1.1618687 1.3707273 -0.46546304
|
||||||
|
-0.14522405 0.29306158 0.41898865 -0.12311607 0.24604419 -0.48434448
|
||||||
|
-0.79076517 0.753163 ]]",SLECHT,11,"[0.5677396655082703, -0.626079797744751, 0.23771092295646667, -0.10016711801290512, 0.8681751489639282, 0.9237187504768372, -0.16313901543617249, -0.37850630283355713, -0.41955387592315674, -0.5862879753112793, -0.3371226191520691, -0.07999519258737564, -0.9812829494476318, 0.18787863850593567, -0.36432507634162903, -0.06605241447687149, 0.007106557488441467, -0.36359935998916626, -0.006429608445614576, 1.257383942604065, 0.6464782357215881, -1.161868691444397, 1.370727300643921, -0.4654630422592163, -0.14522404968738556, 0.29306158423423767, 0.4189886450767517, -0.12311606854200363, 0.24604418873786926, -0.484344482421875, -0.7907651662826538, 0.7531629800796509]"
|
||||||
|
"[[ 0.9251696 -2.8536968 0.6521715 -1.8256427 0.44136596 0.33299872
|
||||||
|
-1.4674346 -0.7897836 0.47153538 0.1437631 0.7396151 1.0851275
|
||||||
|
-1.2100328 1.626544 0.69652647 0.0048107 0.13927785 -0.63199776
|
||||||
|
-3.0121155 2.1738136 -0.7935879 0.2744053 0.281097 -1.0899199
|
||||||
|
-0.7062637 -0.04824958 0.3436054 -0.69366074 0.22799876 0.8932708
|
||||||
|
0.51823634 -0.08065546]]",SMAKELIJK,12,"[0.9251695871353149, -2.853696823120117, 0.6521714925765991, -1.825642704963684, 0.44136595726013184, 0.33299872279167175, -1.4674346446990967, -0.7897835969924927, 0.4715353846549988, 0.14376309514045715, 0.7396150827407837, 1.0851274728775024, -1.2100328207015991, 1.6265439987182617, 0.6965264678001404, 0.004810698330402374, 0.139277845621109, -0.6319977641105652, -3.012115478515625, 2.173813581466675, -0.7935879230499268, 0.274405300617218, 0.2810969948768616, -1.089919924736023, -0.7062637209892273, -0.04824957996606827, 0.3436053991317749, -0.6936607360839844, 0.2279987633228302, 0.8932707905769348, 0.5182363390922546, -0.08065545558929443]"
|
||||||
|
"[[-1.460053 2.9393604 0.5533726 1.3786266 -1.5367306 -1.2867874
|
||||||
|
-1.2442408 0.62938243 -0.7092637 -1.0278705 -1.3296479 -0.8105969
|
||||||
|
0.69997054 0.16618343 0.5113248 0.15563956 -0.4815752 -0.10130378
|
||||||
|
1.4232539 -1.2767068 -0.3229611 0.45140803 -1.3901687 1.0403453
|
||||||
|
1.454544 -1.308266 -1.145199 0.661388 -0.20519465 -1.0480955
|
||||||
|
-0.04290716 -0.36275834]]",SORRY,13,"[-1.4600529670715332, 2.9393603801727295, 0.5533726215362549, 1.3786265850067139, -1.5367306470870972, -1.2867873907089233, -1.2442407608032227, 0.6293824315071106, -0.7092636823654175, -1.027870535850525, -1.3296478986740112, -0.8105968832969666, 0.699970543384552, 0.16618342697620392, 0.5113248229026794, 0.15563955903053284, -0.4815751910209656, -0.10130377858877182, 1.4232538938522339, -1.2767068147659302, -0.32296109199523926, 0.4514080286026001, -1.3901686668395996, 1.040345311164856, 1.454543948173523, -1.308266043663025, -1.145198941230774, 0.6613879799842834, -0.2051946520805359, -1.048095464706421, -0.04290715977549553, -0.3627583384513855]"
|
||||||
|
"[[ 1.0733138 -0.66348886 0.36737278 0.01765811 0.5730918 1.7617474
|
||||||
|
-0.45580968 -1.0973133 0.4730748 -0.4665998 0.48594972 0.548745
|
||||||
|
-0.91712314 0.4846773 0.3774526 -0.5996147 0.28750947 -1.0166394
|
||||||
|
-0.6239797 1.7935088 -0.03666669 -0.51941574 2.045076 -0.9045104
|
||||||
|
-0.6312211 0.9698636 1.0522215 -0.14772844 0.7406032 0.2901447
|
||||||
|
-0.48710442 -0.4619298 ]]",TOT-ZIENS,14,"[1.07331383228302, -0.6634888648986816, 0.3673727810382843, 0.017658105120062828, 0.5730918049812317, 1.7617473602294922, -0.45580968260765076, -1.0973132848739624, 0.4730747938156128, -0.46659979224205017, 0.48594972491264343, 0.5487449765205383, -0.9171231389045715, 0.4846773147583008, 0.3774526119232178, -0.599614679813385, 0.2875094711780548, -1.0166393518447876, -0.6239796876907349, 1.793508768081665, -0.036666687577962875, -0.5194157361984253, 2.0450758934020996, -0.9045103788375854, -0.6312211155891418, 0.9698635935783386, 1.0522215366363525, -0.14772844314575195, 0.7406032085418701, 0.2901447117328644, -0.4871044158935547, -0.4619297981262207]"
|
||||||
|
97 export_embeddings.py Normal file
@@ -0,0 +1,97 @@
import multiprocessing
import os
import torch
import argparse
from datasets.dataset_loader import LocalDatasetLoader
from datasets.embedding_dataset import SLREmbeddingDataset
from torch.utils.data import DataLoader
from datasets import SLREmbeddingDataset, collate_fn_padd
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
import numpy as np
import random
import pandas as pd

seed = 43
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
generator = torch.Generator()
generator.manual_seed(seed)


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


generator = torch.Generator()
generator.manual_seed(seed)
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def parse_args():
    parser = argparse.ArgumentParser(description='Export embeddings')
    parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
    parser.add_argument('--output', type=str, default=None, help='Path to output')
    parser.add_argument('--dataset', type=str, default=None, help='Path to data')
    parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
    args = parser.parse_args()
    return args


args = parse_args()

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# load the model
checkpoint = torch.load(args.checkpoint, map_location=device)

model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)

model.load_state_dict(checkpoint["state_dict"])

dataset_loader = LocalDatasetLoader()
dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)

data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_padd,
    pin_memory=torch.cuda.is_available(),
    # num_workers=0,  # Uncomment this line (and comment out next line) if you want to disable multithreading
    num_workers=multiprocessing.cpu_count(),
    worker_init_fn=seed_worker,
    generator=generator,
)

embeddings = []
k = 0
with torch.no_grad():
    for i, (inputs, labels, masks) in enumerate(data_loader):
        k += 1
        inputs = inputs.to(device)
        masks = masks.to(device)
        outputs = model(inputs, masks)

        for n in range(outputs.shape[0]):
            embeddings.append(outputs[n].cpu().numpy())

df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])

if args.format == 'json':
    df.to_json(args.output, orient='records')
elif args.format == 'csv':
    df.to_csv(args.output, index=False)
67 export_model.py Normal file
@@ -0,0 +1,67 @@
# to run this script, you need torch 1.13.1 and torchvision 0.14.1

import numpy as np
import onnx
import torch
import torchvision
import os

from models.spoter_embedding_model import SPOTER_EMBEDDINGS

# set parameters of the model
model_name = 'fingerspelling_embedding_model'

# load PyTorch model from .pth file

device = torch.device("cpu")
# if torch.cuda.is_available():
#     device = torch.device("cuda")

CHECKPOINT_PATH = "checkpoints/fingerspelling_checkpoint.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
# set model to evaluation mode
model.eval()

dummy_input = torch.randn(1, 10, 54, 2)

# check if models folder exists
if not os.path.exists('out-models'):
    os.makedirs('out-models')

for model_export in ["onnx", "coreml"]:
    if model_export == "coreml":
        # set device for dummy input
        dummy_input = dummy_input.to(device)
        traced_model = torch.jit.trace(model, dummy_input)

        out = traced_model(dummy_input)
        import coremltools as ct

        # Convert to Core ML
        coreml_model = ct.convert(
            traced_model,
            inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
        )

        # Save Core ML model
        coreml_model.save("out-models/" + model_name + ".mlmodel")
    else:
        # set device for dummy input
        dummy_input = dummy_input.to(device)

        torch.onnx.export(model,                     # model being run
                          dummy_input,               # model input (or a tuple for multiple inputs)
                          'out-models/' + model_name + '.onnx',  # where to save the model (can be a file or file-like object)
                          export_params=True,        # store the trained parameter weights inside the model file
                          opset_version=9,           # the ONNX version to export the model to
                          do_constant_folding=True,  # whether to execute constant folding for optimization
                          input_names=['X'],         # the model's input names
                          output_names=['Y']         # the model's output names
                          )
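A minimal sketch of checking the exported ONNX file, assuming `onnxruntime` (already in the requirements), the input/output names `X`/`Y` passed to `torch.onnx.export` above, and the same dummy input shape; the real input shape for a given checkpoint may differ.

```
# Sketch: load the exported ONNX model and run a dummy input through it.
# Assumes export_model.py above produced out-models/fingerspelling_embedding_model.onnx.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("out-models/fingerspelling_embedding_model.onnx")

# Same shape as the dummy_input used during export: (1, 10, 54, 2).
dummy = np.random.randn(1, 10, 54, 2).astype(np.float32)

# Input name "X" and output name "Y" match the names used in torch.onnx.export.
(embedding,) = session.run(["Y"], {"X": dummy})
print(embedding.shape)
```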
73 hyperparam_opt.py Normal file
@@ -0,0 +1,73 @@
from clearml.automation import UniformParameterRange, UniformIntegerParameterRange, DiscreteParameterRange
from clearml.automation import HyperParameterOptimizer
from clearml.automation.optuna import OptimizerOptuna
from optuna.pruners import HyperbandPruner, MedianPruner
from clearml import Task

task = Task.init(
    project_name='SpoterEmbedding',
    task_name='Automatic Hyper-Parameter Optimization',
    task_type=Task.TaskTypes.optimizer,
    reuse_last_task_id=False
)


optimizer = HyperParameterOptimizer(
    # specifying the task to be optimized, task must be in system already so it can be cloned
    base_task_id="4504e0b3ec6745249d3d4c94d3d40652",
    # setting the hyperparameters to optimize
    hyper_parameters=[
        # epochs:
        DiscreteParameterRange('Args/epochs', [200]),
        # learning rate
        UniformParameterRange('Args/lr', 0.000001, 0.01),
        # optimizer
        DiscreteParameterRange('Args/optimizer', ['ADAM', 'SGD']),
        # vector length
        UniformIntegerParameterRange('Args/vector_length', 10, 100),
    ],
    # setting the objective metric we want to maximize/minimize
    objective_metric_title='train_loss',
    objective_metric_series='loss',
    objective_metric_sign='min',

    # setting optimizer
    optimizer_class=OptimizerOptuna,

    # configuring optimization parameters
    execution_queue='default',
    optimization_time_limit=360,
    compute_time_limit=480,
    total_max_jobs=20,
    min_iteration_per_job=0,
    max_iteration_per_job=150000,
    pool_period_min=0.1,

    save_top_k_tasks_only=3,
    optuna_pruner=MedianPruner(),
)


def job_complete_callback(
    job_id,                 # type: str
    objective_value,        # type: float
    objective_iteration,    # type: int
    job_parameters,         # type: dict
    top_performance_job_id  # type: str
):
    print('Job completed!', job_id, objective_value, objective_iteration, job_parameters)
    if job_id == top_performance_job_id:
        print('WOOT WOOT we broke the record! Objective reached {}'.format(objective_value))


task.execute_remotely(queue_name='hypertuning', exit_process=True)

optimizer.set_report_period(0.3)

optimizer.start(job_complete_callback=job_complete_callback)

optimizer.wait()

top_exp = optimizer.get_top_experiments(top_k=3)
print([t.id for t in top_exp])

optimizer.stop()
@@ -88,6 +88,7 @@ def train_epoch_embedding_online(model, epoch_iters, train_loader, val_loader, c
         if enable_batch_sorting:
             if labels_size < train_loader.batch_size:
                 trim_count = labels_size % mini_batch
+                if trim_count > 0:
                     inputs = inputs[:-trim_count]
                     labels = labels[:-trim_count]
                     masks = masks[:-trim_count]
@@ -61,20 +61,25 @@ def map_blazepose_keypoint(column):
     return f"{mapped}_{hand}{suffix}"
 
 
-def map_blazepose_df(df):
+def map_blazepose_df(df, rename=True):
+    to_drop = []
+    if rename:
+        renamings = {}
+        for column in df.columns:
+            mapped_column = map_blazepose_keypoint(column)
+            if mapped_column:
+                renamings[column] = mapped_column
+            else:
+                to_drop.append(column)
+        df = df.rename(columns=renamings)
+
     for index, row in df.iterrows():
 
+        sequence_size = len(row["leftEar_Y"])
         lsx = row["leftShoulder_X"]
         rsx = row["rightShoulder_X"]
         lsy = row["leftShoulder_Y"]
         rsy = row["rightShoulder_Y"]
-        # convert all to list
-        lsx = lsx[1:-1].split(",")
-        rsx = rsx[1:-1].split(",")
-        lsy = lsy[1:-1].split(",")
-        rsy = rsy[1:-1].split(",")
-        sequence_size = len(lsx)
 
         neck_x = []
         neck_y = []
         # Treat each element of the sequence (analyzed frame) individually
@@ -84,4 +89,5 @@ def map_blazepose_df(df):
         df.loc[index, "neck_X"] = str(neck_x)
         df.loc[index, "neck_Y"] = str(neck_y)
 
+    df.drop(columns=to_drop, inplace=True)
     return df
@@ -5,30 +5,30 @@ import pandas as pd
 from normalization.hand_normalization import normalize_hands_full
 from normalization.body_normalization import normalize_body_full
 
-DATASET_PATH = './data/wlasl'
+DATASET_PATH = './data/processed'
 # Load the dataset
-df = pd.read_csv(os.path.join(DATASET_PATH, "WLASL100_train.csv"), encoding="utf-8")
+df = pd.read_csv(os.path.join(DATASET_PATH, "spoter_train.csv"), encoding="utf-8")
 
 print(df.head())
 print(df.columns)
 
 # Retrieve metadata
-video_size_heights = df["video_height"].to_list()
-video_size_widths = df["video_width"].to_list()
+# video_size_heights = df["video_height"].to_list()
+# video_size_widths = df["video_width"].to_list()
 
 # Delete redundant (non-related) properties
-del df["video_height"]
-del df["video_width"]
+# del df["video_height"]
+# del df["video_width"]
 
 # Temporarily remove other relevant metadata
 labels = df["labels"].to_list()
-video_fps = df["fps"].to_list()
+signs = df["sign"].to_list()
 
 del df["labels"]
-del df["fps"]
-del df["split"]
-del df["video_id"]
-del df["label_name"]
-del df["length"]
+del df["sign"]
+del df["path"]
+del df["participant_id"]
+del df["sequence_id"]
 
 # Convert the strings into lists
 
@@ -41,7 +41,7 @@ for column in df.columns:
 
 # Perform the normalizations
 df = normalize_hands_full(df)
-df, invalid_row_indexes = normalize_body_full(df)
+# df, invalid_row_indexes = normalize_body_full(df)
 
 # Clear lists of items from deleted rows
 # labels = [t for i, t in enumerate(labels) if i not in invalid_row_indexes]
@@ -49,6 +49,6 @@ df, invalid_row_indexes = normalize_body_full(df)
 
 # Return the metadata back to the dataset
 df["labels"] = labels
-df["fps"] = video_fps
+df["sign"] = signs
 
-df.to_csv(os.path.join(DATASET_PATH, "wlasl_train_norm.csv"), encoding="utf-8", index=False)
+df.to_csv(os.path.join(DATASET_PATH, "spoter_train_norm.csv"), encoding="utf-8", index=False)
@@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 5,
|
||||||
"id": "c20f7fd5",
|
"id": "c20f7fd5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 6,
|
||||||
"id": "ada032d0",
|
"id": "ada032d0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -22,13 +22,12 @@
|
|||||||
"import os\n",
|
"import os\n",
|
||||||
"import os.path as op\n",
|
"import os.path as op\n",
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import json\n",
|
"import json"
|
||||||
"import base64"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 7,
|
||||||
"id": "05682e73",
|
"id": "05682e73",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -38,7 +37,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 8,
|
||||||
"id": "fede7684",
|
"id": "fede7684",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -48,7 +47,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 9,
|
||||||
"id": "ce531994",
|
"id": "ce531994",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -64,7 +63,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 10,
|
||||||
"id": "f4a2d672",
|
"id": "f4a2d672",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -87,17 +86,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 11,
|
||||||
"id": "1d9db764",
|
"id": "1d9db764",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<torch._C.Generator at 0x7f29f89e3ed0>"
|
"<torch._C.Generator at 0x7f010919d710>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 7,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -119,7 +118,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 12,
|
||||||
"id": "71224139",
|
"id": "71224139",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -155,7 +154,7 @@
|
|||||||
"# checkpoint = torch.load(model.get_weights())\n",
|
"# checkpoint = torch.load(model.get_weights())\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Set your path to checkoint here\n",
|
"## Set your path to checkoint here\n",
|
||||||
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_6.pth\"\n",
|
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
|
||||||
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
|
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = SPOTER_EMBEDDINGS(\n",
|
"model = SPOTER_EMBEDDINGS(\n",
|
||||||
@@ -169,27 +168,28 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 24,
|
||||||
"id": "ba6b58f0",
|
"id": "ba6b58f0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"SL_DATASET = 'wlasl' # or 'lsa'\n",
|
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
|
||||||
"if SL_DATASET == 'wlasl':\n",
|
"\n",
|
||||||
|
"if SL_DATASET == 'fingerspelling':\n",
|
||||||
|
" dataset_name = \"fingerspelling\"\n",
|
||||||
|
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
|
||||||
|
"elif SL_DATASET == 'wlasl':\n",
|
||||||
" dataset_name = \"wlasl\"\n",
|
" dataset_name = \"wlasl\"\n",
|
||||||
" num_classes = 100\n",
|
" split_dataset_path = \"WLASL100_{}.csv\"\n",
|
||||||
" split_dataset_path = \"WLASL100_train.csv\"\n",
|
"elif SL_DATASET == 'basic-signs':\n",
|
||||||
"else:\n",
|
" dataset_name = \"basic-signs\"\n",
|
||||||
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
|
" split_dataset_path = \"basic-signs_{}.csv\"\n",
|
||||||
" num_classes = 64\n",
|
|
||||||
" split_dataset_path = \"LSA64_{}.csv\"\n",
|
|
||||||
" \n",
|
|
||||||
" "
|
" "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 25,
|
||||||
"id": "5643a72c",
|
"id": "5643a72c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -269,6 +269,8 @@
|
|||||||
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
|
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
|
||||||
" k += 1\n",
|
" k += 1\n",
|
||||||
" inputs = inputs.to(device)\n",
|
" inputs = inputs.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
"\n",
|
||||||
" masks = masks.to(device)\n",
|
" masks = masks.to(device)\n",
|
||||||
" outputs = model(inputs, masks)\n",
|
" outputs = model(inputs, masks)\n",
|
||||||
" for n in range(outputs.shape[0]):\n",
|
" for n in range(outputs.shape[0]):\n",
|
||||||
@@ -285,7 +287,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"(810, 810)"
|
"(164, 164)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 19,
|
"execution_count": 19,
|
||||||
@@ -299,7 +301,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 21,
|
||||||
"id": "ab83c6e2",
|
"id": "ab83c6e2",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"lines_to_next_cell": 2
|
"lines_to_next_cell": 2
|
||||||
@@ -311,6 +313,70 @@
|
|||||||
" df['embeddings'] = embeddings_split[split]"
|
" df['embeddings'] = embeddings_split[split]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"id": "0b9fb9c2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
|
||||||
|
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
|
||||||
|
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
|
||||||
|
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
|
||||||
|
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
|
||||||
|
" ... \n",
|
||||||
|
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
|
||||||
|
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
|
||||||
|
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
|
||||||
|
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
|
||||||
|
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
|
||||||
|
"Name: embeddings, Length: 164, dtype: object\n",
|
||||||
|
"0 TOT-ZIENS\n",
|
||||||
|
"1 GOED\n",
|
||||||
|
"2 GOEDENACHT\n",
|
||||||
|
"3 NEE\n",
|
||||||
|
"4 SLECHT\n",
|
||||||
|
" ... \n",
|
||||||
|
"159 SORRY\n",
|
||||||
|
"160 GOEDEMORGEN\n",
|
||||||
|
"161 LINKS\n",
|
||||||
|
"162 TOT-ZIENS\n",
|
||||||
|
"163 GOED\n",
|
||||||
|
"Name: label_name, Length: 164, dtype: object\n",
|
||||||
|
"0 0\n",
|
||||||
|
"1 1\n",
|
||||||
|
"2 2\n",
|
||||||
|
"3 3\n",
|
||||||
|
"4 4\n",
|
||||||
|
" ..\n",
|
||||||
|
"159 7\n",
|
||||||
|
"160 5\n",
|
||||||
|
"161 13\n",
|
||||||
|
"162 0\n",
|
||||||
|
"163 1\n",
|
||||||
|
"Name: labels, Length: 164, dtype: int64\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(dfs['train'][\"embeddings\"])\n",
|
||||||
|
"print(dfs['train'][\"label_name\"])\n",
|
||||||
|
"print(dfs['train'][\"labels\"])\n",
|
||||||
|
"\n",
|
||||||
|
"# only keep these columns\n",
|
||||||
|
"dfs['train'] = dfs['train'][['embeddings', 'label_name', 'labels']]\n",
|
||||||
|
"\n",
|
||||||
|
"# convert embeddings to string\n",
|
||||||
|
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
|
||||||
|
"\n",
|
||||||
|
"# save the dfs['train']\n",
|
||||||
|
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "2951638d",
|
"id": "2951638d",
|
||||||
@@ -322,7 +388,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 23,
|
||||||
"id": "7399b8ae",
|
"id": "7399b8ae",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -331,16 +397,16 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Using centroids only\n",
|
"Using centroids only\n",
|
||||||
"Top-1 accuracy: 5.19 %\n",
|
"Top-1 accuracy: 80.00 %\n",
|
||||||
"Top-5 embeddings class match: 17.65 % (Picks any class in the 5 closest embeddings)\n",
|
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"################################\n",
|
"################################\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Using all embeddings\n",
|
"Using all embeddings\n",
|
||||||
"Top-1 accuracy: 5.31 %\n",
|
"Top-1 accuracy: 80.00 %\n",
|
||||||
"5-nn accuracy: 5.56 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
||||||
"Top-5 embeddings class match: 15.43 % (Picks any class in the 5 closest embeddings)\n",
|
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
|
||||||
"Top-5 unique class match: 15.56 % (Picks the 5 closest distinct classes)\n",
|
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"################################\n",
|
"################################\n",
|
||||||
"\n"
|
"\n"
|
||||||
@@ -375,13 +441,13 @@
|
|||||||
" sorted_labels = labels[argsort]\n",
|
" sorted_labels = labels[argsort]\n",
|
||||||
" if sorted_labels[0] == true_label:\n",
|
" if sorted_labels[0] == true_label:\n",
|
||||||
" top1 += 1\n",
|
" top1 += 1\n",
|
||||||
" if use_centroids:\n",
|
" # if use_centroids:\n",
|
||||||
" good_samples.append(df_val.loc[i, 'video_id'])\n",
|
" # good_samples.append(df_val.loc[i, 'video_id'])\n",
|
||||||
" else:\n",
|
" # else:\n",
|
||||||
" good_samples.append((df_val.loc[i, 'video_id'],\n",
|
" # good_samples.append((df_val.loc[i, 'video_id'],\n",
|
||||||
" df_train.loc[argsort[0], 'video_id'],\n",
|
" # df_train.loc[argsort[0], 'video_id'],\n",
|
||||||
" i,\n",
|
" # i,\n",
|
||||||
" argsort[0]))\n",
|
" # argsort[0]))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
|
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,5 +1,5 @@
 from argparse import ArgumentParser
-from preprocessing.create_wlasl_landmarks_dataset import parse_create_args, create
+from preprocessing.create_fingerspelling_dataset import parse_create_args, create
 from preprocessing.extract_mediapipe_landmarks import parse_extract_args, extract
 
 
174 preprocessing/create_fingerspelling_dataset.py Normal file
@@ -0,0 +1,174 @@
import os
import os.path as op
import json
import shutil

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from utils import get_logger
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from normalization.blazepose_mapping import map_blazepose_df

BASE_DATA_FOLDER = 'data/'

mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_holistic = mp.solutions.holistic
pose_landmarks = mp_holistic.PoseLandmark
hand_landmarks = mp_holistic.HandLandmark


def get_landmarks_names():
    '''
    Returns landmark names for mediapipe holistic model
    '''
    pose_lmks = ','.join([f'{lmk.name.lower()}_x,{lmk.name.lower()}_y' for lmk in pose_landmarks])
    left_hand_lmks = ','.join([f'left_hand_{lmk.name.lower()}_x,left_hand_{lmk.name.lower()}_y'
                               for lmk in hand_landmarks])
    right_hand_lmks = ','.join([f'right_hand_{lmk.name.lower()}_x,right_hand_{lmk.name.lower()}_y'
                                for lmk in hand_landmarks])
    lmks_names = f'{pose_lmks},{left_hand_lmks},{right_hand_lmks}'
    return lmks_names


def convert_to_str(arr, precision=6):
    if isinstance(arr, np.ndarray):
        values = []
        for val in arr:
            if val == 0:
                values.append('0')
            else:
                values.append(f'{val:.{precision}f}')
        return f"[{','.join(values)}]"
    else:
        return str(arr)


def parse_create_args(parser):
    parser.add_argument('--landmarks-dataset', '-lmks', required=True,
                        help='Path to folder with landmarks npy files. \
                              You need to run `extract_mediapipe_landmarks.py` script first')
    parser.add_argument('--dataset-folder', '-df', default='data/wlasl',
                        help='Path to folder where original `WLASL_v0.3.json` and `id_to_label.json` are stored. \
                              Note that final CSV files will be saved in this folder too.')
    parser.add_argument('--videos-folder', '-videos', default=None,
                        help='Path to folder with videos. If None, then no information of videos (fps, length, \
                              width and height) will be stored in final csv file')
    parser.add_argument('--num-classes', '-nc', default=100, type=int, help='Number of classes to use in WLASL dataset')
    parser.add_argument('--create-new-split', action='store_true')
    parser.add_argument('--test-size', '-ts', default=0.25, type=float,
                        help='Test split percentage size. Only required if --create-new-split is set')


# python3 preprocessing.py --landmarks-dataset=data/landmarks -videos data/wlasl/videos
def create(args):
    logger = get_logger(__name__)

    landmarks_dataset = args.landmarks_dataset
    videos_folder = args.videos_folder
    dataset_folder = args.dataset_folder
    num_classes = args.num_classes
    test_size = args.test_size

    os.makedirs(dataset_folder, exist_ok=True)

    # shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/id_to_label.json'), dataset_folder)
    # shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/WLASL_v0.3.json'), dataset_folder)

    # get files in landmarks_dataset folder
    landmarks_files = os.listdir(landmarks_dataset)

    video_data = []
    for i, file in enumerate(tqdm(landmarks_files)):

        # split by !
        label = file.split('!')[0]
        subset = file.split('!')[1].split('.')[0]

        # remove npy and set mp4
        video_id = file.replace('.npy', "")

        video_dict = {'video_id': video_id,
                      'label_name': label,
                      'split': subset}

        if videos_folder is not None:
            cap = cv2.VideoCapture(op.join(videos_folder, f'{video_id}.mp4'))
            if not cap.isOpened():
                logger.warning(f'Video {video_id}.mp4 not found')
                continue
            width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            fps = cap.get(cv2.CAP_PROP_FPS)
            length = cap.get(cv2.CAP_PROP_FRAME_COUNT) / float(cap.get(cv2.CAP_PROP_FPS))
            video_info = {'video_width': width,
                          'video_height': height,
                          'fps': fps,
                          'length': length}
            video_dict.update(video_info)
        video_data.append(video_dict)

    df_video = pd.DataFrame(video_data)
    video_ids = df_video['video_id'].unique()
    lmks_data = []
    lmks_names = get_landmarks_names().split(',')

    # get labels from df_video
    labels = df_video['label_name'].unique()
    # map labels to ids
    label_to_id = {label: i for i, label in enumerate(labels)}

    # add label_id column to df_video
    df_video['labels'] = df_video['label_name'].map(label_to_id)

    # export to json file as id to label
    id_to_label = {i: label for label, i in label_to_id.items()}
    with open(op.join(dataset_folder, 'id_to_label.json'), 'w') as f:
        json.dump(id_to_label, f, indent=4)

    for video_id in video_ids:
        lmk_fn = op.join(landmarks_dataset, f'{video_id}.npy')
        if not op.exists(lmk_fn):
            logger.warning(f'{lmk_fn} file not found. Skipping')
            continue
        lmk = np.load(lmk_fn).T
        lmks_dict = {'video_id': video_id}
        for lmk_, name in zip(lmk, lmks_names):
            lmks_dict[name] = lmk_
        lmks_data.append(lmks_dict)

    df_lmks = pd.DataFrame(lmks_data)
    df = pd.merge(df_video, df_lmks)
    aux_columns = ['split', 'video_id', 'labels', 'label_name']
    if videos_folder is not None:
        aux_columns += ['video_width', 'video_height', 'fps', 'length']
    df_aux = df[aux_columns]
    df = map_blazepose_df(df)
    df = pd.concat([df, df_aux], axis=1)
    if args.create_new_split:
        df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)

        print(f'Num classes: {num_classes}')
        print(df_train['labels'].value_counts())
        print(df_test['labels'].value_counts())
        assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
        )), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
            the --create-new-split flag'
        for split, df_split in zip(['train', 'val'],
                                   [df_train, df_test]):
            fn_out = op.join(dataset_folder, f'{split}.csv')
            (df_split.reset_index(drop=True)
             .applymap(convert_to_str)
             .to_csv(fn_out, index=False))

    else:
        fn_out = op.join(dataset_folder, 'train.csv')
        (df.reset_index(drop=True)
         .applymap(convert_to_str)
         .to_csv(fn_out, index=False))
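A short sketch of the file-naming convention the loop above relies on, derived from the `split('!')` parsing; the example label is hypothetical. Each landmarks file is expected to be named `<label>!<split>.npy`.

```
# Sketch mirroring the parsing in create(); "HELLO!train.npy" is a hypothetical file name.
file = "HELLO!train.npy"
label = file.split('!')[0]                  # -> "HELLO"        (label_name column)
subset = file.split('!')[1].split('.')[0]   # -> "train"        (split column)
video_id = file.replace('.npy', "")         # -> "HELLO!train"  (video_id column)
print(label, subset, video_id)
```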
@@ -4,6 +4,8 @@ import pandas as pd
 from tqdm.auto import tqdm
 import json
+
+from normalization.blazepose_mapping import map_blazepose_df
 
 def create(train_landmark_files, train_csv, dataset_folder, test_size):
     os.makedirs(dataset_folder, exist_ok=True)
 
@@ -17,15 +19,15 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
     mapping = {
         'pose_0': 'nose',
         'pose_1': 'leftEye',
-        'pose_2': 'rightEye',
-        'pose_3': 'leftEar',
-        'pose_4': 'rightEar',
-        'pose_5': 'leftShoulder',
-        'pose_6': 'rightShoulder',
-        'pose_7': 'leftElbow',
-        'pose_8': 'rightElbow',
-        'pose_9': 'leftWrist',
-        'pose_10': 'rightWrist',
+        'pose_4': 'rightEye',
+        'pose_7': 'leftEar',
+        'pose_8': 'rightEar',
+        'pose_11': 'leftShoulder',
+        'pose_12': 'rightShoulder',
+        'pose_13': 'leftElbow',
+        'pose_14': 'rightElbow',
+        'pose_15': 'leftWrist',
+        'pose_16': 'rightWrist',
 
         'left_hand_0': 'wrist_left',
         'left_hand_1': 'thumbCMC_left',
@@ -77,7 +79,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
         columns.append(f'{v}_X')
         columns.append(f'{v}_Y')
 
-    for _, row in tqdm(train_df.head(6000).iterrows(), total=6000):
+    for _, row in tqdm(train_df.head(10000).iterrows(), total=10000):
         path, participant_id, sequence_id, sign = row['path'], row['participant_id'], row['sequence_id'], row['sign']
         parquet_file = os.path.join(train_landmark_files, str(participant_id), f"{sequence_id}.parquet")
 
@@ -136,6 +138,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
         video_data.append(new_landmark_data)
 
     video_data = pd.concat(video_data, axis=0, ignore_index=True)
+    video_data = map_blazepose_df(video_data, rename=False)
     video_data.to_csv(os.path.join(dataset_folder, 'spoter.csv'), index=False)
 
 train_landmark_files = 'data/train_landmark_files'
|||||||
@@ -110,6 +110,7 @@ def create(args):
                       'length': length}
         video_dict.update(video_info)
         video_data.append(video_dict)
 
     df_video = pd.DataFrame(video_data)
     video_ids = df_video['video_id'].unique()
     lmks_data = []
@@ -126,9 +127,7 @@ def create(args):
         lmks_data.append(lmks_dict)
 
     df_lmks = pd.DataFrame(lmks_data)
-    print(df_lmks)
     df = pd.merge(df_video, df_lmks)
-    print(df)
     aux_columns = ['split', 'video_id', 'labels', 'label_name']
     if videos_folder is not None:
         aux_columns += ['video_width', 'video_height', 'fps', 'length']
@@ -37,6 +37,10 @@ class LandmarksResults:
         self.num_landmarks_pose = num_landmarks_pose
         self.num_landmarks_hand = num_landmarks_hand
 
+    @property
+    def empty(self):
+        return self.results.pose_landmarks is None or (self.results.left_hand_landmarks is None and self.results.right_hand_landmarks is None)
+
     @property
     def pose_landmarks(self):
         if self.results.pose_landmarks is None:
@@ -67,6 +71,10 @@ def get_landmarks(image_orig, holistic, debug=False):
     # Convert the BGR image to RGB before processing.
     image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
     results = LandmarksResults(holistic.process(image))
 
+    if results.empty:
+        return None
+
     if debug:
         lmks_pose = []
         for lmk in results.pose_landmarks:
@@ -94,6 +102,7 @@ def get_landmarks(image_orig, holistic, debug=False):
         len(lmks_right_hand) == 2 * LEN_LANDMARKS_HAND
     ), f"{len(lmks_right_hand)} != {2 * LEN_LANDMARKS_HAND}"
     landmarks = []
 
     for lmk in chain(
         results.pose_landmarks,
         results.left_hand_landmarks,
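Some background on the objects being wrapped here (an aside about the MediaPipe Holistic API, not this repository's code): `holistic.process()` returns a result whose `pose_landmarks`, `left_hand_landmarks` and `right_hand_landmarks` attributes are either `None` or landmark lists with normalised coordinates; the new `empty` property treats a frame as unusable when no pose was found or when neither hand was found, so such frames are dropped early. A minimal sketch, assuming a single test image `frame.jpg`:
```python
import cv2
import mediapipe as mp

holistic = mp.solutions.holistic.Holistic(static_image_mode=True)
image = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
raw = holistic.process(image)

if raw.pose_landmarks is None:
    print("no pose detected in this frame")
else:
    # Each landmark carries normalised x and y (plus z and visibility).
    for lm in raw.pose_landmarks.landmark[:3]:
        print(lm.x, lm.y)
```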
@@ -128,10 +137,21 @@ def extract(args):
     videos_folder = args.videos_folder
     os.makedirs(landmarks_output, exist_ok=True)
     for fn_video in tqdm(sorted(glob.glob(op.join(videos_folder, "*mp4")))):
+
+        # check if landmarks already exist
+        if op.exists(op.join(landmarks_output, op.basename(fn_video).split(".")[0] + ".npy")):
+            continue
+
         cap = cv2.VideoCapture(fn_video)
         ret, image_orig = cap.read()
         height, width = image_orig.shape[:2]
         landmarks_video = []
 
+        # make sure fps is 20 by determining the number of frames to be skipped
+        frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
+        frame_skip = (frame_rate // 10) - 1
+
         with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:
             with mp_holistic.Holistic(
                 static_image_mode=False,
@@ -145,6 +165,10 @@ def extract(args):
                     print(e)
                 landmarks = get_landmarks(image_orig, holistic, debug=True)
                 ret, image_orig = cap.read()
+                for _ in range(frame_skip):
+                    ret, image_orig = cap.read()
+                    pbar.update(1)
+                if landmarks:
                     landmarks_video.append(landmarks)
                 pbar.update(1)
         landmarks_video = np.vstack(landmarks_video)
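A note on the frame-skipping arithmetic above (an aside, not part of the diff): `frame_skip` extra frames are read and discarded after every processed frame, so the effective sampling rate is roughly `frame_rate / (frame_skip + 1)`, i.e. about 10 fps, even though the inline comment mentions 20; one of the two is presumably stale. A quick check, assuming a 30 fps source video:
```python
frame_rate = 30                      # as reported by cap.get(cv2.CAP_PROP_FPS)
frame_skip = (frame_rate // 10) - 1  # -> 2 frames discarded per processed frame
effective_fps = frame_rate / (frame_skip + 1)
print(frame_skip, effective_fps)     # 2 10.0
```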
@@ -8,7 +8,6 @@ dataset = "data/processed/spoter.csv"
 
 # read the dataset
 df = pd.read_csv(dataset)
-df = map_blazepose_df(df)
 
 with open("data/sign_to_prediction_index_map.json", "r") as f:
     sign_to_prediction_index_max = json.load(f)
@@ -17,6 +16,9 @@ with open("data/sign_to_prediction_index_map.json", "r") as f:
 # filter df to make sure each sign has at least 4 samples
 df = df[df["sign"].map(df["sign"].value_counts()) > 4]
 
+# print number of unique signs
+print("Number of unique signs: ", len(df["sign"].unique()))
+
 # use the path column to split the dataset
 paths = df["path"].unique()
 
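A quick note on the filter kept in this hunk (illustrative only): `df["sign"].map(df["sign"].value_counts())` attaches each row's class count to that row, and the strict `> 4` keeps signs with at least five samples, one more than the comment states; `>= 4` would match the comment. A toy check:
```python
import pandas as pd

df = pd.DataFrame({"sign": ["a"] * 5 + ["b"] * 4 + ["c"] * 3})
kept = df[df["sign"].map(df["sign"].value_counts()) > 4]
print(kept["sign"].unique())   # ['a'] -- 'b' (4 samples) is dropped by the strict inequality
```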
@@ -1,7 +1,6 @@
 pandas
 bokeh==2.4.3
 boto3>=1.9
-clearml==1.6.4
 ipywidgets==8.0.4
 matplotlib==3.5.3
 mediapipe==0.8.11
@@ -9,6 +8,9 @@ notebook==6.5.2
 opencv-python==4.6.0.66
 plotly==5.11.0
 scikit-learn==1.0.2
-torch
-torchvision
+clearml==1.10.3
+torch==2.0.0
+torchvision==0.15.1
 tqdm==4.54.1
+optuna==3.1.1
+onnx==1.14.0
train.py (22 changed lines)
@@ -15,7 +15,7 @@ from torchvision import transforms
 from torch.utils.data import DataLoader
 from pathlib import Path
 import copy
+import numpy as np
 from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
 from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
     train_epoch_embedding_online, evaluate_embedding
@@ -32,7 +32,7 @@ except ImportError:
     pass
 
 
-PROJECT_NAME = "spoter"
+PROJECT_NAME = "SpoterEmbedding"
 CLEARML = "clearml"
 
 
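For reference on the ClearML side of this merge request (an illustration, not code from train.py): when `--tracker=clearml` is selected, experiment tracking normally starts from `Task.init` with `PROJECT_NAME` as the ClearML project; the task name below is hypothetical, and the repository wraps these calls behind its `Tracker` abstraction.
```python
from clearml import Task

task = Task.init(project_name="SpoterEmbedding", task_name="Finetune Fingerspelling Signs")
logger = task.get_logger()

# Mirrors the tracker.log_scalar_metric("acc", "val", epoch, val_acc) call used in train.py.
logger.report_scalar(title="acc", series="val", iteration=0, value=0.0)
```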
@@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
 
     # Construct the model
     if not args.classification_model:
+        # if finetune, load the weights from the classification model
+        if args.finetune:
+            checkpoint = torch.load(args.checkpoint_path, map_location=device)
+
+            slrt_model = SPOTER_EMBEDDINGS(
+                features=checkpoint["config_args"].vector_length,
+                hidden_dim=checkpoint["config_args"].hidden_dim,
+                norm_emb=checkpoint["config_args"].normalize_embeddings,
+                dropout=checkpoint["config_args"].dropout,
+            )
+        else:
             slrt_model = SPOTER_EMBEDDINGS(
                 features=args.vector_length,
                 hidden_dim=args.hidden_dim,
                 norm_emb=args.normalize_embeddings,
-                dropout=args.dropout
+                dropout=args.dropout,
             )
 
 
     model_type = 'embed'
     if args.hard_triplet_mining == "None":
         cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
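On the finetuning branch above (a sketch under assumptions, not the repository's exact loading code): the checkpoint is expected to carry the training arguments under `config_args`, so the embedding model can be rebuilt with the same architecture before its weights are restored. The restore itself is outside this hunk, but conventionally it would look roughly like the following, where the `"state_dict"` key is an assumption:
```python
import torch
from models import SPOTER_EMBEDDINGS

checkpoint = torch.load("checkpoints/checkpoint_embed_3835.pth", map_location="cpu")
cfg = checkpoint["config_args"]          # argparse.Namespace saved alongside the weights
model = SPOTER_EMBEDDINGS(features=cfg.vector_length,
                          hidden_dim=cfg.hidden_dim,
                          norm_emb=cfg.normalize_embeddings,
                          dropout=cfg.dropout)
model.load_state_dict(checkpoint["state_dict"])   # assumed key for the saved weights
```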
@@ -233,6 +246,9 @@ def train(args, tracker: Tracker):
         val_accs.append(val_acc)
         tracker.log_scalar_metric("acc", "val", epoch, val_acc)
 
+        create_embedding_scatter_plots(tracker, slrt_model, train_loader, val_loader, device, id_to_label, epoch,
+                                       top_model_name)
+
         logger.info(f"Epoch time: {datetime.now() - start_time}")
         logger.info("[" + str(epoch) + "] TRAIN loss: " + str(train_loss) + " acc: " + str(train_accs[-1]))
         logger.info("[" + str(epoch) + "] VALIDATION acc: " + str(val_accs[-1]))
train.sh (28 changed lines)
@@ -1,22 +1,24 @@
 #!/bin/sh
-python -m train \
+python3 -m train \
     --save_checkpoints_every 10 \
-    --experiment_name "augment_rotate_75_x8" \
-    --epochs 300 \
+    --experiment_name "Finetune Fingerspelling Signs" \
+    --epochs 1000 \
     --optimizer "ADAM" \
-    --lr 0.001 \
-    --batch_size 16 \
-    --dataset_name "processed" \
-    --training_set_path "spoter_train.csv" \
-    --validation_set_path "spoter_test.csv" \
+    --lr 0.00001 \
+    --batch_size 8 \
+    --dataset_name "FingerSpelling" \
+    --training_set_path "train.csv" \
+    --validation_set_path "val.csv" \
     --vector_length 32 \
     --epoch_iters -1 \
     --scheduler_factor 0 \
     --hard_triplet_mining "in_batch" \
     --filter_easy_triplets \
-    --triplet_loss_margin 1 \
+    --start_mining_hard 50 \
+    --triplet_loss_margin 4 \
     --dropout 0.2 \
-    --augmentations_prob=0.75 \
-    --hard_mining_scheduler_triplets_threshold=0 \
-    --normalize_embeddings \
-    --num_classes 100 \
+    --tracker=clearml \
+    --dataset_loader=clearml \
+    --dataset_project="SpoterEmbedding" \
+    --finetune \
+    --checkpoint_path "checkpoints/checkpoint_embed_3835.pth"
@@ -81,4 +81,7 @@ def get_default_args():
                         help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
                         distance threshold of the batch sorter")
 
+    parser.add_argument("--finetune", action='store_true', default=False, help="Finetune the model")
+    parser.add_argument("--checkpoint_path", type=str, default="")
+
     return parser
visualize_data.ipynb (new file, 1636 lines; diff suppressed because one or more lines are too long)
@@ -238,6 +238,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
                          f"{kk}-dimensional vectors")
 
     if m*n*k <= threshold:
+        print("Using minkowski_distance")
         return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)
     else:
         result = np.empty((m,n),dtype=float)  # FIXME: figure out the best dtype
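Context for the threshold check in this function (an aside): the vectorised branch broadcasts the two point sets into an (m, n, k) intermediate array before reducing over the last axis, so it is only taken while `m*n*k` stays below `threshold`; otherwise the (m, n) result is filled row by row to bound memory. The same idea with SciPy's public helper:
```python
import numpy as np
from scipy.spatial import minkowski_distance

m, n, k = 4, 3, 32                       # m query vectors, n reference vectors, k dimensions
x, y = np.random.rand(m, k), np.random.rand(n, k)

# Vectorised branch: broadcasting materialises an (m, n, k) intermediate.
d_fast = minkowski_distance(x[:, np.newaxis, :], y[np.newaxis, :, :], 2)

# Memory-bounded branch: one row of the (m, n) result at a time.
d_slow = np.empty((m, n), dtype=float)
for i in range(m):
    d_slow[i, :] = minkowski_distance(x[i], y, 2)

assert np.allclose(d_fast, d_slow)
```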