Merge branch 'FingerspellingEmbedding-+-ClearML' into 'main'

Fingerspelling embedding + ClearML

See merge request wesign/spoterembedding!1
This commit is contained in:
Victor Mylle 2023-05-21 20:30:13 +00:00
commit 36d09ae49f
26 changed files with 2465 additions and 176 deletions

1
.gitignore vendored
View File

@ -155,3 +155,4 @@ out-img/
converted_models/
*.pth
*.onnx
.devcontainer

20
README2.md Normal file
View File

@ -0,0 +1,20 @@
# Spoter Embeddings
## Creating dataset
First, make a folder where all you're videos are located. When this is done, all keypoints can be extracted from the videos using the following command. This will extract the keypoints and store them in \<path-to-landmarks-folder\>.
```
python3 preprocessing.py extract --videos-folder <path-to-videos-folder> --output-landmark <path-to-landmarks-folder>
```
When this is done, the dataset can be created using the following command:
```
python3 preprocessing.py create --landmarks-dataset <path-to-landmarks-folder> --videos-folder <path-to-videos-folder> --dataset-folder <dataset-output-folder> (--create-new-split --test-size <test-percentage>)
```
The above command generates a train (and val) csv file which includes all the extracted keypoints. These can then be used to train or generates embeddings.
## Creating Embeddings
The embeddings can be created using the following command:
```
python3 export_embeddings.py --checkpoint <path-to-checkpoints-file> --dataset <path-to-dataset-file> --output <embeddings-output-file>
```
The command above generates the embeddings for a given dataset and saves them as a csv file.

Binary file not shown.

Binary file not shown.

View File

@ -17,5 +17,4 @@ requests==2.28.1
onnx==1.12.0
onnx-tf==1.10.0
onnxruntime==1.12.1
tensorflow
tensorflow-probability
coremltools==6.3.0

View File

@ -0,0 +1,91 @@
embeddings,label_name,labels,embeddings2
"[[ 0.01295076 -0.1074132 0.5111911 0.08452003 -0.7638924 0.6458221
-1.3892679 -1.1791427 -0.30736607 -0.41543546 -0.6358013 0.31411174
0.878936 1.7265923 -0.09298562 0.12005516 0.26995468 -1.5934305
-0.12619524 0.9111336 0.91827893 0.5948979 0.90334046 -0.84089845
-0.29575598 -0.3024254 -0.03074584 1.4402957 -0.22309914 0.54750854
-0.68767273 -0.0665718 ]]",BEDANKEN,0,"[0.012950755655765533, -0.10741320252418518, 0.5111911296844482, 0.08452003449201584, -0.763892412185669, 0.6458221077919006, -1.389267921447754, -1.179142713546753, -0.3073660731315613, -0.41543546319007874, -0.6358013153076172, 0.31411173939704895, 0.8789359927177429, 1.7265923023223877, -0.0929856151342392, 0.12005516141653061, 0.2699546813964844, -1.593430519104004, -0.12619523704051971, 0.9111335873603821, 0.9182789325714111, 0.5948979258537292, 0.9033404588699341, -0.8408984541893005, -0.2957559823989868, -0.30242541432380676, -0.030745841562747955, 1.440295696258545, -0.22309914231300354, 0.5475085377693176, -0.6876727342605591, -0.06657180190086365]"
"[[ 1.0287632 -0.94657993 0.7852507 -0.37169546 0.3614158 0.5360078
-0.9989285 -0.24402924 -0.47437504 -0.17307773 0.9196662 0.5362411
-0.50366765 0.7600024 0.6701926 -0.5288625 -0.38939306 -1.0722619
-1.33836 1.331829 -0.6958481 0.57767504 0.87002045 -0.59203666
-0.57576096 0.70008135 0.8224436 -0.32370126 0.5276206 -0.62150276
-0.9639951 -0.40879178]]",GOED,1,"[1.0287631750106812, -0.9465799331665039, 0.785250723361969, -0.37169545888900757, 0.3614158034324646, 0.536007821559906, -0.9989284873008728, -0.244029238820076, -0.47437503933906555, -0.17307773232460022, 0.9196661710739136, 0.5362411141395569, -0.5036676526069641, 0.7600023746490479, 0.6701925992965698, -0.528862476348877, -0.38939306139945984, -1.072261929512024, -1.3383599519729614, 1.3318289518356323, -0.6958481073379517, 0.5776750445365906, 0.8700204491615295, -0.5920366644859314, -0.5757609605789185, 0.7000813484191895, 0.8224436044692993, -0.32370126247406006, 0.5276206135749817, -0.6215027570724487, -0.963995099067688, -0.40879178047180176]"
"[[-0.5988377 -1.6454532 0.5696761 0.27801913 -0.39687645 1.5192578
-0.57908726 -1.9851501 1.1918951 -1.3280408 0.27390718 0.71642965
-1.0471189 0.9825219 -0.6625729 -0.28641602 2.0109527 -0.97667754
-2.3978348 2.666861 0.15066166 -1.6922158 0.9681514 -1.1413386
0.08027886 -0.6215762 -0.09759905 0.6157277 -0.01341684 0.9806912
-0.09793527 -1.9003383 ]]",GOEDEMIDDAG,2,"[-0.598837673664093, -1.6454532146453857, 0.5696761012077332, 0.27801913022994995, -0.3968764543533325, 1.5192577838897705, -0.5790872573852539, -1.9851500988006592, 1.1918951272964478, -1.3280408382415771, 0.2739071846008301, 0.7164296507835388, -1.047118902206421, 0.9825218915939331, -0.6625729203224182, -0.28641602396965027, 2.0109527111053467, -0.9766775369644165, -2.3978347778320312, 2.666861057281494, 0.1506616622209549, -1.6922158002853394, 0.9681513905525208, -1.141338586807251, 0.08027885854244232, -0.621576189994812, -0.09759905189275742, 0.6157277226448059, -0.013416841626167297, 0.9806911945343018, -0.0979352742433548, -1.9003382921218872]"
"[[-0.11578767 0.28371835 0.8604608 -0.31061298 -0.78153455 -0.24914108
-2.443886 -0.3267159 0.29237527 -0.7879326 -0.12082034 0.32809633
-1.0938975 0.7125962 0.49117383 -0.3119098 0.15369919 0.39569452
-1.18983 1.2498406 -0.38752007 -0.54896545 -0.620718 0.02654967
-0.10994279 -0.39713857 -0.18035033 -1.554728 0.12178666 -0.5698373
1.1097438 -1.2350351 ]]",GOEDEMORGEN,3,"[-0.11578767001628876, 0.2837183475494385, 0.8604608178138733, -0.3106129765510559, -0.7815345525741577, -0.24914108216762543, -2.4438860416412354, -0.326715886592865, 0.29237526655197144, -0.7879325747489929, -0.12082034349441528, 0.32809633016586304, -1.0938974618911743, 0.7125961780548096, 0.4911738336086273, -0.3119097948074341, 0.15369918942451477, 0.3956945240497589, -1.18982994556427, 1.2498406171798706, -0.38752007484436035, -0.5489654541015625, -0.6207180023193359, 0.02654966711997986, -0.10994279384613037, -0.3971385657787323, -0.18035033345222473, -1.5547280311584473, 0.12178666144609451, -0.5698372721672058, 1.1097438335418701, -1.2350350618362427]"
"[[-0.9557147 0.03368077 1.0923759 -0.41077584 -1.6992741 -0.59566665
-1.1562454 0.8824008 -0.25122 0.625515 1.6720039 0.5639413
-0.35894975 -0.05896414 0.84854823 0.4072139 -1.1763096 -1.5435666
-1.5066593 0.66299886 -1.3121625 1.4410669 -0.6341419 -0.4873513
0.8113033 -0.01502099 0.32262453 2.5855515 -0.7294207 0.7582124
-0.1386989 -0.17011487]]",GOEDENACHT,4,"[-0.9557147026062012, 0.03368076682090759, 1.0923758745193481, -0.4107758402824402, -1.6992740631103516, -0.5956666469573975, -1.1562453508377075, 0.8824008107185364, -0.2512199878692627, 0.6255149841308594, 1.6720038652420044, 0.5639412999153137, -0.35894975066185, -0.058964136987924576, 0.8485482335090637, 0.40721389651298523, -1.176309585571289, -1.5435665845870972, -1.5066592693328857, 0.6629988551139832, -1.3121625185012817, 1.441066861152649, -0.6341419219970703, -0.48735129833221436, 0.8113033175468445, -0.015020990744233131, 0.322624534368515, 2.5855515003204346, -0.7294207215309143, 0.7582123875617981, -0.13869890570640564, -0.17011487483978271]"
"[[ 0.13687058 -1.099074 1.1867181 -1.2216619 -0.965039 -0.6957289
-1.6400954 0.7317837 0.5887569 0.18756844 1.4867041 0.63357025
-1.5853169 0.4157976 0.76919526 0.08512082 -1.0937556 -0.07339763
-2.8580015 1.6835626 -1.4538742 1.2302468 -0.49323368 -0.81243044
-0.11378415 -0.09592562 0.31165385 -0.32617894 0.02981896 -0.01485211
1.1880591 -0.40726334]]",GOEDENAVOND,5,"[0.1368705779314041, -1.0990740060806274, 1.1867181062698364, -1.221661925315857, -0.9650390148162842, -0.6957288980484009, -1.6400953531265259, 0.7317836880683899, 0.5887569189071655, 0.18756844103336334, 1.4867041110992432, 0.6335702538490295, -1.5853168964385986, 0.4157975912094116, 0.7691952586174011, 0.08512081950902939, -1.0937556028366089, -0.07339762896299362, -2.858001470565796, 1.6835626363754272, -1.4538742303848267, 1.2302467823028564, -0.49323368072509766, -0.8124304413795471, -0.11378414928913116, -0.09592562168836594, 0.31165385246276855, -0.3261789381504059, 0.029818959534168243, -0.014852114021778107, 1.1880590915679932, -0.4072633385658264]"
"[[ 1.3673804 1.7083405 0.58775777 0.26056585 -0.29101595 -0.69954723
0.45681733 1.6908163 0.02684825 0.7537957 2.2054908 0.28831166
-1.5712252 -1.6869793 0.8009781 -0.51264 -1.3544708 -0.72502786
-0.31009498 -0.23777717 -1.7524906 1.6333568 1.5263942 0.22317217
-0.40576094 2.2136266 1.4030526 -1.0599502 0.8686069 -1.1141618
-0.01899157 -1.2256656 ]]",JA,6,"[1.3673803806304932, 1.7083405256271362, 0.5877577662467957, 0.260565847158432, -0.29101595282554626, -0.6995472311973572, 0.4568173289299011, 1.6908162832260132, 0.02684824913740158, 0.7537956833839417, 2.205490827560425, 0.2883116602897644, -1.5712251663208008, -1.6869792938232422, 0.8009781241416931, -0.5126399993896484, -1.3544708490371704, -0.725027859210968, -0.3100949823856354, -0.23777717351913452, -1.7524906396865845, 1.6333568096160889, 1.526394248008728, 0.2231721729040146, -0.4057609438896179, 2.2136266231536865, 1.403052568435669, -1.0599502325057983, 0.8686069250106812, -1.1141618490219116, -0.01899157091975212, -1.22566556930542]"
"[[ 1.8208723 -0.59625745 0.15060106 -0.23053254 1.0344827 1.7016335
0.1968202 -0.48188782 0.3004466 -0.53336793 0.89704275 0.26869062
-1.0112952 0.0343281 0.32149622 -0.8189871 0.33571526 -0.57059044
-0.7577773 1.4743904 -0.01735806 -0.4116615 2.3637483 -0.6602343
-0.6101709 1.8139238 1.041902 -0.9006022 0.6082014 -0.23616463
-0.62204087 -0.524899 ]]",LINKS,7,"[1.8208723068237305, -0.5962574481964111, 0.15060105919837952, -0.23053254187107086, 1.034482717514038, 1.7016334533691406, 0.1968201994895935, -0.4818878173828125, 0.30044659972190857, -0.533367931842804, 0.8970427513122559, 0.2686906158924103, -1.011295199394226, 0.034328099340200424, 0.32149621844291687, -0.8189870715141296, 0.33571526408195496, -0.5705904364585876, -0.7577772736549377, 1.4743903875350952, -0.01735806092619896, -0.4116615056991577, 2.36374831199646, -0.660234272480011, -0.6101709008216858, 1.8139238357543945, 1.04190194606781, -0.9006022214889526, 0.6082013845443726, -0.2361646294593811, -0.622040867805481, -0.5248990058898926]"
"[[ 1.3968729 2.340378 0.567814 0.5684975 -0.8795973 -1.0090083
-0.01301649 1.9747832 -0.3257468 0.76960254 1.402801 0.39424965
0.06948688 -0.51904166 1.0961802 -0.34335235 -1.8730735 -1.0801809
0.2751789 -0.96998405 -1.9822828 2.1046805 1.3094746 0.5770473
-0.19693638 1.8973799 0.99536693 0.2668684 0.28499565 -0.9838486
-0.19578832 -0.46982569]]",NEE,8,"[1.396872878074646, 2.3403780460357666, 0.5678139925003052, 0.5684974789619446, -0.8795973062515259, -1.0090082883834839, -0.013016488403081894, 1.974783182144165, -0.3257468044757843, 0.7696025371551514, 1.4028010368347168, 0.39424964785575867, 0.06948687881231308, -0.5190416574478149, 1.0961802005767822, -0.343352347612381, -1.8730734586715698, -1.0801808834075928, 0.2751789093017578, -0.9699840545654297, -1.9822827577590942, 2.1046805381774902, 1.3094745874404907, 0.5770472884178162, -0.19693638384342194, 1.8973798751831055, 0.9953669309616089, 0.2668684124946594, 0.2849956452846527, -0.9838485717773438, -0.19578832387924194, -0.4698256850242615]"
"[[ 1.8143647 -0.9315294 0.28706568 -1.0938002 -0.2451038 0.42331484
-0.31653756 0.00268575 0.89822304 1.0639151 1.7406261 1.1707302
-1.6639041 0.30558044 0.22984704 -1.0260515 -0.08818819 -0.14696571
-1.5269523 0.55015475 -0.12936743 0.93951946 1.3917109 -0.62517196
-1.3869016 1.4833041 1.4647726 -1.0285482 -0.3704591 1.4672718
0.40315557 0.04754414]]",RECHTS,9,"[1.8143646717071533, -0.9315294027328491, 0.28706568479537964, -1.0938001871109009, -0.24510380625724792, 0.4233148396015167, -0.3165375590324402, 0.002685748040676117, 0.8982230424880981, 1.0639151334762573, 1.7406260967254639, 1.1707302331924438, -1.663904070854187, 0.3055804371833801, 0.2298470437526703, -1.0260515213012695, -0.08818819373846054, -0.14696571230888367, -1.5269522666931152, 0.5501547455787659, -0.1293674260377884, 0.939519464969635, 1.391710877418518, -0.625171959400177, -1.386901617050171, 1.4833041429519653, 1.4647725820541382, -1.028548240661621, -0.37045910954475403, 1.4672718048095703, 0.4031555652618408, 0.04754413664340973]"
"[[ 1.681777 -0.79441184 0.07110912 -0.01304399 1.110081 2.1115103
-0.27283993 -1.5035654 0.55927217 -0.9903696 0.05718866 0.19891238
-0.7289791 0.71738267 -0.41428873 -0.58389515 1.1087953 -0.6616588
-0.6330258 1.5830203 0.60013187 -1.0640136 2.3917787 -0.70012283
-0.7359694 0.9847692 0.94458187 -0.5915589 0.3662199 0.39359796
-0.6499281 -0.30565363]]",SALUUT,10,"[1.681777000427246, -0.794411838054657, 0.07110912352800369, -0.013043994084000587, 1.1100809574127197, 2.1115102767944336, -0.2728399336338043, -1.5035654306411743, 0.5592721700668335, -0.9903696179389954, 0.05718865618109703, 0.1989123821258545, -0.7289791107177734, 0.7173826694488525, -0.414288729429245, -0.5838951468467712, 1.1087952852249146, -0.6616588234901428, -0.6330258250236511, 1.5830203294754028, 0.6001318693161011, -1.0640136003494263, 2.3917787075042725, -0.7001228332519531, -0.7359694242477417, 0.9847692251205444, 0.9445818662643433, -0.5915588736534119, 0.3662199079990387, 0.39359796047210693, -0.649928092956543, -0.3056536316871643]"
"[[ 0.56773967 -0.6260798 0.23771092 -0.10016712 0.86817515 0.92371875
-0.16313902 -0.3785063 -0.41955388 -0.586288 -0.33712262 -0.07999519
-0.98128295 0.18787864 -0.36432508 -0.06605241 0.00710656 -0.36359936
-0.00642961 1.257384 0.64647824 -1.1618687 1.3707273 -0.46546304
-0.14522405 0.29306158 0.41898865 -0.12311607 0.24604419 -0.48434448
-0.79076517 0.753163 ]]",SLECHT,11,"[0.5677396655082703, -0.626079797744751, 0.23771092295646667, -0.10016711801290512, 0.8681751489639282, 0.9237187504768372, -0.16313901543617249, -0.37850630283355713, -0.41955387592315674, -0.5862879753112793, -0.3371226191520691, -0.07999519258737564, -0.9812829494476318, 0.18787863850593567, -0.36432507634162903, -0.06605241447687149, 0.007106557488441467, -0.36359935998916626, -0.006429608445614576, 1.257383942604065, 0.6464782357215881, -1.161868691444397, 1.370727300643921, -0.4654630422592163, -0.14522404968738556, 0.29306158423423767, 0.4189886450767517, -0.12311606854200363, 0.24604418873786926, -0.484344482421875, -0.7907651662826538, 0.7531629800796509]"
"[[ 0.9251696 -2.8536968 0.6521715 -1.8256427 0.44136596 0.33299872
-1.4674346 -0.7897836 0.47153538 0.1437631 0.7396151 1.0851275
-1.2100328 1.626544 0.69652647 0.0048107 0.13927785 -0.63199776
-3.0121155 2.1738136 -0.7935879 0.2744053 0.281097 -1.0899199
-0.7062637 -0.04824958 0.3436054 -0.69366074 0.22799876 0.8932708
0.51823634 -0.08065546]]",SMAKELIJK,12,"[0.9251695871353149, -2.853696823120117, 0.6521714925765991, -1.825642704963684, 0.44136595726013184, 0.33299872279167175, -1.4674346446990967, -0.7897835969924927, 0.4715353846549988, 0.14376309514045715, 0.7396150827407837, 1.0851274728775024, -1.2100328207015991, 1.6265439987182617, 0.6965264678001404, 0.004810698330402374, 0.139277845621109, -0.6319977641105652, -3.012115478515625, 2.173813581466675, -0.7935879230499268, 0.274405300617218, 0.2810969948768616, -1.089919924736023, -0.7062637209892273, -0.04824957996606827, 0.3436053991317749, -0.6936607360839844, 0.2279987633228302, 0.8932707905769348, 0.5182363390922546, -0.08065545558929443]"
"[[-1.460053 2.9393604 0.5533726 1.3786266 -1.5367306 -1.2867874
-1.2442408 0.62938243 -0.7092637 -1.0278705 -1.3296479 -0.8105969
0.69997054 0.16618343 0.5113248 0.15563956 -0.4815752 -0.10130378
1.4232539 -1.2767068 -0.3229611 0.45140803 -1.3901687 1.0403453
1.454544 -1.308266 -1.145199 0.661388 -0.20519465 -1.0480955
-0.04290716 -0.36275834]]",SORRY,13,"[-1.4600529670715332, 2.9393603801727295, 0.5533726215362549, 1.3786265850067139, -1.5367306470870972, -1.2867873907089233, -1.2442407608032227, 0.6293824315071106, -0.7092636823654175, -1.027870535850525, -1.3296478986740112, -0.8105968832969666, 0.699970543384552, 0.16618342697620392, 0.5113248229026794, 0.15563955903053284, -0.4815751910209656, -0.10130377858877182, 1.4232538938522339, -1.2767068147659302, -0.32296109199523926, 0.4514080286026001, -1.3901686668395996, 1.040345311164856, 1.454543948173523, -1.308266043663025, -1.145198941230774, 0.6613879799842834, -0.2051946520805359, -1.048095464706421, -0.04290715977549553, -0.3627583384513855]"
"[[ 1.0733138 -0.66348886 0.36737278 0.01765811 0.5730918 1.7617474
-0.45580968 -1.0973133 0.4730748 -0.4665998 0.48594972 0.548745
-0.91712314 0.4846773 0.3774526 -0.5996147 0.28750947 -1.0166394
-0.6239797 1.7935088 -0.03666669 -0.51941574 2.045076 -0.9045104
-0.6312211 0.9698636 1.0522215 -0.14772844 0.7406032 0.2901447
-0.48710442 -0.4619298 ]]",TOT-ZIENS,14,"[1.07331383228302, -0.6634888648986816, 0.3673727810382843, 0.017658105120062828, 0.5730918049812317, 1.7617473602294922, -0.45580968260765076, -1.0973132848739624, 0.4730747938156128, -0.46659979224205017, 0.48594972491264343, 0.5487449765205383, -0.9171231389045715, 0.4846773147583008, 0.3774526119232178, -0.599614679813385, 0.2875094711780548, -1.0166393518447876, -0.6239796876907349, 1.793508768081665, -0.036666687577962875, -0.5194157361984253, 2.0450758934020996, -0.9045103788375854, -0.6312211155891418, 0.9698635935783386, 1.0522215366363525, -0.14772844314575195, 0.7406032085418701, 0.2901447117328644, -0.4871044158935547, -0.4619297981262207]"
1 embeddings label_name labels embeddings2
2 [[ 0.01295076 -0.1074132 0.5111911 0.08452003 -0.7638924 0.6458221 -1.3892679 -1.1791427 -0.30736607 -0.41543546 -0.6358013 0.31411174 0.878936 1.7265923 -0.09298562 0.12005516 0.26995468 -1.5934305 -0.12619524 0.9111336 0.91827893 0.5948979 0.90334046 -0.84089845 -0.29575598 -0.3024254 -0.03074584 1.4402957 -0.22309914 0.54750854 -0.68767273 -0.0665718 ]] BEDANKEN 0 [0.012950755655765533, -0.10741320252418518, 0.5111911296844482, 0.08452003449201584, -0.763892412185669, 0.6458221077919006, -1.389267921447754, -1.179142713546753, -0.3073660731315613, -0.41543546319007874, -0.6358013153076172, 0.31411173939704895, 0.8789359927177429, 1.7265923023223877, -0.0929856151342392, 0.12005516141653061, 0.2699546813964844, -1.593430519104004, -0.12619523704051971, 0.9111335873603821, 0.9182789325714111, 0.5948979258537292, 0.9033404588699341, -0.8408984541893005, -0.2957559823989868, -0.30242541432380676, -0.030745841562747955, 1.440295696258545, -0.22309914231300354, 0.5475085377693176, -0.6876727342605591, -0.06657180190086365]
3 [[ 1.0287632 -0.94657993 0.7852507 -0.37169546 0.3614158 0.5360078 -0.9989285 -0.24402924 -0.47437504 -0.17307773 0.9196662 0.5362411 -0.50366765 0.7600024 0.6701926 -0.5288625 -0.38939306 -1.0722619 -1.33836 1.331829 -0.6958481 0.57767504 0.87002045 -0.59203666 -0.57576096 0.70008135 0.8224436 -0.32370126 0.5276206 -0.62150276 -0.9639951 -0.40879178]] GOED 1 [1.0287631750106812, -0.9465799331665039, 0.785250723361969, -0.37169545888900757, 0.3614158034324646, 0.536007821559906, -0.9989284873008728, -0.244029238820076, -0.47437503933906555, -0.17307773232460022, 0.9196661710739136, 0.5362411141395569, -0.5036676526069641, 0.7600023746490479, 0.6701925992965698, -0.528862476348877, -0.38939306139945984, -1.072261929512024, -1.3383599519729614, 1.3318289518356323, -0.6958481073379517, 0.5776750445365906, 0.8700204491615295, -0.5920366644859314, -0.5757609605789185, 0.7000813484191895, 0.8224436044692993, -0.32370126247406006, 0.5276206135749817, -0.6215027570724487, -0.963995099067688, -0.40879178047180176]
4 [[-0.5988377 -1.6454532 0.5696761 0.27801913 -0.39687645 1.5192578 -0.57908726 -1.9851501 1.1918951 -1.3280408 0.27390718 0.71642965 -1.0471189 0.9825219 -0.6625729 -0.28641602 2.0109527 -0.97667754 -2.3978348 2.666861 0.15066166 -1.6922158 0.9681514 -1.1413386 0.08027886 -0.6215762 -0.09759905 0.6157277 -0.01341684 0.9806912 -0.09793527 -1.9003383 ]] GOEDEMIDDAG 2 [-0.598837673664093, -1.6454532146453857, 0.5696761012077332, 0.27801913022994995, -0.3968764543533325, 1.5192577838897705, -0.5790872573852539, -1.9851500988006592, 1.1918951272964478, -1.3280408382415771, 0.2739071846008301, 0.7164296507835388, -1.047118902206421, 0.9825218915939331, -0.6625729203224182, -0.28641602396965027, 2.0109527111053467, -0.9766775369644165, -2.3978347778320312, 2.666861057281494, 0.1506616622209549, -1.6922158002853394, 0.9681513905525208, -1.141338586807251, 0.08027885854244232, -0.621576189994812, -0.09759905189275742, 0.6157277226448059, -0.013416841626167297, 0.9806911945343018, -0.0979352742433548, -1.9003382921218872]
5 [[-0.11578767 0.28371835 0.8604608 -0.31061298 -0.78153455 -0.24914108 -2.443886 -0.3267159 0.29237527 -0.7879326 -0.12082034 0.32809633 -1.0938975 0.7125962 0.49117383 -0.3119098 0.15369919 0.39569452 -1.18983 1.2498406 -0.38752007 -0.54896545 -0.620718 0.02654967 -0.10994279 -0.39713857 -0.18035033 -1.554728 0.12178666 -0.5698373 1.1097438 -1.2350351 ]] GOEDEMORGEN 3 [-0.11578767001628876, 0.2837183475494385, 0.8604608178138733, -0.3106129765510559, -0.7815345525741577, -0.24914108216762543, -2.4438860416412354, -0.326715886592865, 0.29237526655197144, -0.7879325747489929, -0.12082034349441528, 0.32809633016586304, -1.0938974618911743, 0.7125961780548096, 0.4911738336086273, -0.3119097948074341, 0.15369918942451477, 0.3956945240497589, -1.18982994556427, 1.2498406171798706, -0.38752007484436035, -0.5489654541015625, -0.6207180023193359, 0.02654966711997986, -0.10994279384613037, -0.3971385657787323, -0.18035033345222473, -1.5547280311584473, 0.12178666144609451, -0.5698372721672058, 1.1097438335418701, -1.2350350618362427]
6 [[-0.9557147 0.03368077 1.0923759 -0.41077584 -1.6992741 -0.59566665 -1.1562454 0.8824008 -0.25122 0.625515 1.6720039 0.5639413 -0.35894975 -0.05896414 0.84854823 0.4072139 -1.1763096 -1.5435666 -1.5066593 0.66299886 -1.3121625 1.4410669 -0.6341419 -0.4873513 0.8113033 -0.01502099 0.32262453 2.5855515 -0.7294207 0.7582124 -0.1386989 -0.17011487]] GOEDENACHT 4 [-0.9557147026062012, 0.03368076682090759, 1.0923758745193481, -0.4107758402824402, -1.6992740631103516, -0.5956666469573975, -1.1562453508377075, 0.8824008107185364, -0.2512199878692627, 0.6255149841308594, 1.6720038652420044, 0.5639412999153137, -0.35894975066185, -0.058964136987924576, 0.8485482335090637, 0.40721389651298523, -1.176309585571289, -1.5435665845870972, -1.5066592693328857, 0.6629988551139832, -1.3121625185012817, 1.441066861152649, -0.6341419219970703, -0.48735129833221436, 0.8113033175468445, -0.015020990744233131, 0.322624534368515, 2.5855515003204346, -0.7294207215309143, 0.7582123875617981, -0.13869890570640564, -0.17011487483978271]
7 [[ 0.13687058 -1.099074 1.1867181 -1.2216619 -0.965039 -0.6957289 -1.6400954 0.7317837 0.5887569 0.18756844 1.4867041 0.63357025 -1.5853169 0.4157976 0.76919526 0.08512082 -1.0937556 -0.07339763 -2.8580015 1.6835626 -1.4538742 1.2302468 -0.49323368 -0.81243044 -0.11378415 -0.09592562 0.31165385 -0.32617894 0.02981896 -0.01485211 1.1880591 -0.40726334]] GOEDENAVOND 5 [0.1368705779314041, -1.0990740060806274, 1.1867181062698364, -1.221661925315857, -0.9650390148162842, -0.6957288980484009, -1.6400953531265259, 0.7317836880683899, 0.5887569189071655, 0.18756844103336334, 1.4867041110992432, 0.6335702538490295, -1.5853168964385986, 0.4157975912094116, 0.7691952586174011, 0.08512081950902939, -1.0937556028366089, -0.07339762896299362, -2.858001470565796, 1.6835626363754272, -1.4538742303848267, 1.2302467823028564, -0.49323368072509766, -0.8124304413795471, -0.11378414928913116, -0.09592562168836594, 0.31165385246276855, -0.3261789381504059, 0.029818959534168243, -0.014852114021778107, 1.1880590915679932, -0.4072633385658264]
8 [[ 1.3673804 1.7083405 0.58775777 0.26056585 -0.29101595 -0.69954723 0.45681733 1.6908163 0.02684825 0.7537957 2.2054908 0.28831166 -1.5712252 -1.6869793 0.8009781 -0.51264 -1.3544708 -0.72502786 -0.31009498 -0.23777717 -1.7524906 1.6333568 1.5263942 0.22317217 -0.40576094 2.2136266 1.4030526 -1.0599502 0.8686069 -1.1141618 -0.01899157 -1.2256656 ]] JA 6 [1.3673803806304932, 1.7083405256271362, 0.5877577662467957, 0.260565847158432, -0.29101595282554626, -0.6995472311973572, 0.4568173289299011, 1.6908162832260132, 0.02684824913740158, 0.7537956833839417, 2.205490827560425, 0.2883116602897644, -1.5712251663208008, -1.6869792938232422, 0.8009781241416931, -0.5126399993896484, -1.3544708490371704, -0.725027859210968, -0.3100949823856354, -0.23777717351913452, -1.7524906396865845, 1.6333568096160889, 1.526394248008728, 0.2231721729040146, -0.4057609438896179, 2.2136266231536865, 1.403052568435669, -1.0599502325057983, 0.8686069250106812, -1.1141618490219116, -0.01899157091975212, -1.22566556930542]
9 [[ 1.8208723 -0.59625745 0.15060106 -0.23053254 1.0344827 1.7016335 0.1968202 -0.48188782 0.3004466 -0.53336793 0.89704275 0.26869062 -1.0112952 0.0343281 0.32149622 -0.8189871 0.33571526 -0.57059044 -0.7577773 1.4743904 -0.01735806 -0.4116615 2.3637483 -0.6602343 -0.6101709 1.8139238 1.041902 -0.9006022 0.6082014 -0.23616463 -0.62204087 -0.524899 ]] LINKS 7 [1.8208723068237305, -0.5962574481964111, 0.15060105919837952, -0.23053254187107086, 1.034482717514038, 1.7016334533691406, 0.1968201994895935, -0.4818878173828125, 0.30044659972190857, -0.533367931842804, 0.8970427513122559, 0.2686906158924103, -1.011295199394226, 0.034328099340200424, 0.32149621844291687, -0.8189870715141296, 0.33571526408195496, -0.5705904364585876, -0.7577772736549377, 1.4743903875350952, -0.01735806092619896, -0.4116615056991577, 2.36374831199646, -0.660234272480011, -0.6101709008216858, 1.8139238357543945, 1.04190194606781, -0.9006022214889526, 0.6082013845443726, -0.2361646294593811, -0.622040867805481, -0.5248990058898926]
10 [[ 1.3968729 2.340378 0.567814 0.5684975 -0.8795973 -1.0090083 -0.01301649 1.9747832 -0.3257468 0.76960254 1.402801 0.39424965 0.06948688 -0.51904166 1.0961802 -0.34335235 -1.8730735 -1.0801809 0.2751789 -0.96998405 -1.9822828 2.1046805 1.3094746 0.5770473 -0.19693638 1.8973799 0.99536693 0.2668684 0.28499565 -0.9838486 -0.19578832 -0.46982569]] NEE 8 [1.396872878074646, 2.3403780460357666, 0.5678139925003052, 0.5684974789619446, -0.8795973062515259, -1.0090082883834839, -0.013016488403081894, 1.974783182144165, -0.3257468044757843, 0.7696025371551514, 1.4028010368347168, 0.39424964785575867, 0.06948687881231308, -0.5190416574478149, 1.0961802005767822, -0.343352347612381, -1.8730734586715698, -1.0801808834075928, 0.2751789093017578, -0.9699840545654297, -1.9822827577590942, 2.1046805381774902, 1.3094745874404907, 0.5770472884178162, -0.19693638384342194, 1.8973798751831055, 0.9953669309616089, 0.2668684124946594, 0.2849956452846527, -0.9838485717773438, -0.19578832387924194, -0.4698256850242615]
11 [[ 1.8143647 -0.9315294 0.28706568 -1.0938002 -0.2451038 0.42331484 -0.31653756 0.00268575 0.89822304 1.0639151 1.7406261 1.1707302 -1.6639041 0.30558044 0.22984704 -1.0260515 -0.08818819 -0.14696571 -1.5269523 0.55015475 -0.12936743 0.93951946 1.3917109 -0.62517196 -1.3869016 1.4833041 1.4647726 -1.0285482 -0.3704591 1.4672718 0.40315557 0.04754414]] RECHTS 9 [1.8143646717071533, -0.9315294027328491, 0.28706568479537964, -1.0938001871109009, -0.24510380625724792, 0.4233148396015167, -0.3165375590324402, 0.002685748040676117, 0.8982230424880981, 1.0639151334762573, 1.7406260967254639, 1.1707302331924438, -1.663904070854187, 0.3055804371833801, 0.2298470437526703, -1.0260515213012695, -0.08818819373846054, -0.14696571230888367, -1.5269522666931152, 0.5501547455787659, -0.1293674260377884, 0.939519464969635, 1.391710877418518, -0.625171959400177, -1.386901617050171, 1.4833041429519653, 1.4647725820541382, -1.028548240661621, -0.37045910954475403, 1.4672718048095703, 0.4031555652618408, 0.04754413664340973]
12 [[ 1.681777 -0.79441184 0.07110912 -0.01304399 1.110081 2.1115103 -0.27283993 -1.5035654 0.55927217 -0.9903696 0.05718866 0.19891238 -0.7289791 0.71738267 -0.41428873 -0.58389515 1.1087953 -0.6616588 -0.6330258 1.5830203 0.60013187 -1.0640136 2.3917787 -0.70012283 -0.7359694 0.9847692 0.94458187 -0.5915589 0.3662199 0.39359796 -0.6499281 -0.30565363]] SALUUT 10 [1.681777000427246, -0.794411838054657, 0.07110912352800369, -0.013043994084000587, 1.1100809574127197, 2.1115102767944336, -0.2728399336338043, -1.5035654306411743, 0.5592721700668335, -0.9903696179389954, 0.05718865618109703, 0.1989123821258545, -0.7289791107177734, 0.7173826694488525, -0.414288729429245, -0.5838951468467712, 1.1087952852249146, -0.6616588234901428, -0.6330258250236511, 1.5830203294754028, 0.6001318693161011, -1.0640136003494263, 2.3917787075042725, -0.7001228332519531, -0.7359694242477417, 0.9847692251205444, 0.9445818662643433, -0.5915588736534119, 0.3662199079990387, 0.39359796047210693, -0.649928092956543, -0.3056536316871643]
13 [[ 0.56773967 -0.6260798 0.23771092 -0.10016712 0.86817515 0.92371875 -0.16313902 -0.3785063 -0.41955388 -0.586288 -0.33712262 -0.07999519 -0.98128295 0.18787864 -0.36432508 -0.06605241 0.00710656 -0.36359936 -0.00642961 1.257384 0.64647824 -1.1618687 1.3707273 -0.46546304 -0.14522405 0.29306158 0.41898865 -0.12311607 0.24604419 -0.48434448 -0.79076517 0.753163 ]] SLECHT 11 [0.5677396655082703, -0.626079797744751, 0.23771092295646667, -0.10016711801290512, 0.8681751489639282, 0.9237187504768372, -0.16313901543617249, -0.37850630283355713, -0.41955387592315674, -0.5862879753112793, -0.3371226191520691, -0.07999519258737564, -0.9812829494476318, 0.18787863850593567, -0.36432507634162903, -0.06605241447687149, 0.007106557488441467, -0.36359935998916626, -0.006429608445614576, 1.257383942604065, 0.6464782357215881, -1.161868691444397, 1.370727300643921, -0.4654630422592163, -0.14522404968738556, 0.29306158423423767, 0.4189886450767517, -0.12311606854200363, 0.24604418873786926, -0.484344482421875, -0.7907651662826538, 0.7531629800796509]
14 [[ 0.9251696 -2.8536968 0.6521715 -1.8256427 0.44136596 0.33299872 -1.4674346 -0.7897836 0.47153538 0.1437631 0.7396151 1.0851275 -1.2100328 1.626544 0.69652647 0.0048107 0.13927785 -0.63199776 -3.0121155 2.1738136 -0.7935879 0.2744053 0.281097 -1.0899199 -0.7062637 -0.04824958 0.3436054 -0.69366074 0.22799876 0.8932708 0.51823634 -0.08065546]] SMAKELIJK 12 [0.9251695871353149, -2.853696823120117, 0.6521714925765991, -1.825642704963684, 0.44136595726013184, 0.33299872279167175, -1.4674346446990967, -0.7897835969924927, 0.4715353846549988, 0.14376309514045715, 0.7396150827407837, 1.0851274728775024, -1.2100328207015991, 1.6265439987182617, 0.6965264678001404, 0.004810698330402374, 0.139277845621109, -0.6319977641105652, -3.012115478515625, 2.173813581466675, -0.7935879230499268, 0.274405300617218, 0.2810969948768616, -1.089919924736023, -0.7062637209892273, -0.04824957996606827, 0.3436053991317749, -0.6936607360839844, 0.2279987633228302, 0.8932707905769348, 0.5182363390922546, -0.08065545558929443]
15 [[-1.460053 2.9393604 0.5533726 1.3786266 -1.5367306 -1.2867874 -1.2442408 0.62938243 -0.7092637 -1.0278705 -1.3296479 -0.8105969 0.69997054 0.16618343 0.5113248 0.15563956 -0.4815752 -0.10130378 1.4232539 -1.2767068 -0.3229611 0.45140803 -1.3901687 1.0403453 1.454544 -1.308266 -1.145199 0.661388 -0.20519465 -1.0480955 -0.04290716 -0.36275834]] SORRY 13 [-1.4600529670715332, 2.9393603801727295, 0.5533726215362549, 1.3786265850067139, -1.5367306470870972, -1.2867873907089233, -1.2442407608032227, 0.6293824315071106, -0.7092636823654175, -1.027870535850525, -1.3296478986740112, -0.8105968832969666, 0.699970543384552, 0.16618342697620392, 0.5113248229026794, 0.15563955903053284, -0.4815751910209656, -0.10130377858877182, 1.4232538938522339, -1.2767068147659302, -0.32296109199523926, 0.4514080286026001, -1.3901686668395996, 1.040345311164856, 1.454543948173523, -1.308266043663025, -1.145198941230774, 0.6613879799842834, -0.2051946520805359, -1.048095464706421, -0.04290715977549553, -0.3627583384513855]
16 [[ 1.0733138 -0.66348886 0.36737278 0.01765811 0.5730918 1.7617474 -0.45580968 -1.0973133 0.4730748 -0.4665998 0.48594972 0.548745 -0.91712314 0.4846773 0.3774526 -0.5996147 0.28750947 -1.0166394 -0.6239797 1.7935088 -0.03666669 -0.51941574 2.045076 -0.9045104 -0.6312211 0.9698636 1.0522215 -0.14772844 0.7406032 0.2901447 -0.48710442 -0.4619298 ]] TOT-ZIENS 14 [1.07331383228302, -0.6634888648986816, 0.3673727810382843, 0.017658105120062828, 0.5730918049812317, 1.7617473602294922, -0.45580968260765076, -1.0973132848739624, 0.4730747938156128, -0.46659979224205017, 0.48594972491264343, 0.5487449765205383, -0.9171231389045715, 0.4846773147583008, 0.3774526119232178, -0.599614679813385, 0.2875094711780548, -1.0166393518447876, -0.6239796876907349, 1.793508768081665, -0.036666687577962875, -0.5194157361984253, 2.0450758934020996, -0.9045103788375854, -0.6312211155891418, 0.9698635935783386, 1.0522215366363525, -0.14772844314575195, 0.7406032085418701, 0.2901447117328644, -0.4871044158935547, -0.4619297981262207]

97
export_embeddings.py Normal file
View File

@ -0,0 +1,97 @@
import multiprocessing
import os
import torch
import argparse
from datasets.dataset_loader import LocalDatasetLoader
from datasets.embedding_dataset import SLREmbeddingDataset
from torch.utils.data import DataLoader
from datasets import SLREmbeddingDataset, collate_fn_padd
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
import numpy as np
import random
import pandas as pd
seed = 43
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
generator = torch.Generator()
generator.manual_seed(seed)
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
generator = torch.Generator()
generator.manual_seed(seed)
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def parse_args():
parser = argparse.ArgumentParser(description='Export embeddings')
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint')
parser.add_argument('--output', type=str, default=None, help='Path to output')
parser.add_argument('--dataset', type=str, default=None, help='Path to data')
parser.add_argument('--format', type=str, default='csv', help='Format of the output file (csv, json)')
args = parser.parse_args()
return args
args = parse_args()
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
# load the model
checkpoint = torch.load(args.checkpoint, map_location=device)
model = SPOTER_EMBEDDINGS(
features=checkpoint["config_args"].vector_length,
hidden_dim=checkpoint["config_args"].hidden_dim,
norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
dataset_loader = LocalDatasetLoader()
dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)
data_loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=collate_fn_padd,
pin_memory=torch.cuda.is_available(),
#num_workers=0, # Uncomment this line (and comment out next line) if you want to disable multithreading
num_workers=multiprocessing.cpu_count(),
worker_init_fn=seed_worker,
generator=generator,
)
embeddings = []
k = 0
with torch.no_grad():
for i, (inputs, labels, masks) in enumerate(data_loader):
k += 1
inputs = inputs.to(device)
masks = masks.to(device)
outputs = model(inputs, masks)
for n in range(outputs.shape[0]):
embeddings.append(outputs[n].cpu().numpy())
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
df['embeddings'] = df['embeddings'].apply(lambda x: x.tolist()[0])
if args.format == 'json':
df.to_json(args.output, orient='records')
elif args.format == 'csv':
df.to_csv(args.output, index=False)

67
export_model.py Normal file
View File

@ -0,0 +1,67 @@
# to run this script, you need torch 1.13.1 and torchvision 0.14.1
import numpy as np
import onnx
import torch
import torchvision
import os
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
# set parameters of the model
model_name = 'fingerspelling_embedding_model'
# load PyTorch model from .pth file
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda")
CHECKPOINT_PATH = "checkpoints/fingerspelling_checkpoint.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model = SPOTER_EMBEDDINGS(
features=checkpoint["config_args"].vector_length,
hidden_dim=checkpoint["config_args"].hidden_dim,
norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
# set model to evaluation mode
model.eval()
dummy_input = torch.randn(1, 10, 54, 2)
# check if models folder exists
if not os.path.exists('out-models'):
os.makedirs('out-models')
for model_export in ["onnx", "coreml"]:
if model_export == "coreml":
# set device for dummy input
dummy_input = dummy_input.to(device)
traced_model = torch.jit.trace(model, dummy_input)
out = traced_model(dummy_input)
import coremltools as ct
# Convert to Core ML
coreml_model = ct.convert(
traced_model,
inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
)
# Save Core ML model
coreml_model.save("out-models/" + model_name + ".mlmodel")
else:
# set device for dummy input
dummy_input = dummy_input.to(device)
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
'out-models/' + model_name + '.onnx', # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=9, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['X'], # the model's input names
output_names = ['Y'] # the model's output names
)

73
hyperparam_opt.py Normal file
View File

@ -0,0 +1,73 @@
from clearml.automation import UniformParameterRange, UniformIntegerParameterRange, DiscreteParameterRange
from clearml.automation import HyperParameterOptimizer
from clearml.automation.optuna import OptimizerOptuna
from optuna.pruners import HyperbandPruner, MedianPruner
from clearml import Task
task = Task.init(
project_name='SpoterEmbedding',
task_name='Automatic Hyper-Parameter Optimization',
task_type=Task.TaskTypes.optimizer,
reuse_last_task_id=False
)
optimizer = HyperParameterOptimizer(
# specifying the task to be optimized, task must be in system already so it can be cloned
base_task_id="4504e0b3ec6745249d3d4c94d3d40652",
# setting the hyperparameters to optimize
hyper_parameters=[
# epochs:
DiscreteParameterRange('Args/epochs', [200]),
# learning rate
UniformParameterRange('Args/lr', 0.000001, 0.01),
# optimizer
DiscreteParameterRange('Args/optimizer', ['ADAM', 'SGD']),
# vector length
UniformIntegerParameterRange('Args/vector_length', 10, 100),
],
# setting the objective metric we want to maximize/minimize
objective_metric_title='train_loss',
objective_metric_series='loss',
objective_metric_sign='min',
# setting optimizer
optimizer_class=OptimizerOptuna,
# configuring optimization parameters
execution_queue='default',
optimization_time_limit=360,
compute_time_limit=480,
total_max_jobs=20,
min_iteration_per_job=0,
max_iteration_per_job=150000,
pool_period_min=0.1,
save_top_k_tasks_only=3,
optuna_pruner=MedianPruner(),
)
def job_complete_callback(
job_id, # type: str
objective_value, # type: float
objective_iteration, # type: int
job_parameters, # type: dict
top_performance_job_id # type: str
):
print('Job completed!', job_id, objective_value, objective_iteration, job_parameters)
if job_id == top_performance_job_id:
print('WOOT WOOT we broke the record! Objective reached {}'.format(objective_value))
task.execute_remotely(queue_name='hypertuning', exit_process=True)
optimizer.set_report_period(0.3)
optimizer.start(job_complete_callback=job_complete_callback)
optimizer.wait()
top_exp = optimizer.get_top_experiments(top_k=3)
print([t.id for t in top_exp])
optimizer.stop()

View File

@ -88,9 +88,10 @@ def train_epoch_embedding_online(model, epoch_iters, train_loader, val_loader, c
if enable_batch_sorting:
if labels_size < train_loader.batch_size:
trim_count = labels_size % mini_batch
inputs = inputs[:-trim_count]
labels = labels[:-trim_count]
masks = masks[:-trim_count]
if trim_count > 0:
inputs = inputs[:-trim_count]
labels = labels[:-trim_count]
masks = masks[:-trim_count]
embeddings = None
with torch.no_grad():
for j in range(batch_loop_count):

View File

@ -61,20 +61,25 @@ def map_blazepose_keypoint(column):
return f"{mapped}_{hand}{suffix}"
def map_blazepose_df(df):
def map_blazepose_df(df, rename=True):
to_drop = []
if rename:
renamings = {}
for column in df.columns:
mapped_column = map_blazepose_keypoint(column)
if mapped_column:
renamings[column] = mapped_column
else:
to_drop.append(column)
df = df.rename(columns=renamings)
for index, row in df.iterrows():
sequence_size = len(row["leftEar_Y"])
lsx = row["leftShoulder_X"]
rsx = row["rightShoulder_X"]
lsy = row["leftShoulder_Y"]
rsy = row["rightShoulder_Y"]
# convert all to list
lsx = lsx[1:-1].split(",")
rsx = rsx[1:-1].split(",")
lsy = lsy[1:-1].split(",")
rsy = rsy[1:-1].split(",")
sequence_size = len(lsx)
neck_x = []
neck_y = []
# Treat each element of the sequence (analyzed frame) individually
@ -84,4 +89,5 @@ def map_blazepose_df(df):
df.loc[index, "neck_X"] = str(neck_x)
df.loc[index, "neck_Y"] = str(neck_y)
df.drop(columns=to_drop, inplace=True)
return df

View File

@ -5,30 +5,30 @@ import pandas as pd
from normalization.hand_normalization import normalize_hands_full
from normalization.body_normalization import normalize_body_full
DATASET_PATH = './data/wlasl'
DATASET_PATH = './data/processed'
# Load the dataset
df = pd.read_csv(os.path.join(DATASET_PATH, "WLASL100_train.csv"), encoding="utf-8")
df = pd.read_csv(os.path.join(DATASET_PATH, "spoter_train.csv"), encoding="utf-8")
print(df.head())
print(df.columns)
# Retrieve metadata
video_size_heights = df["video_height"].to_list()
video_size_widths = df["video_width"].to_list()
# video_size_heights = df["video_height"].to_list()
# video_size_widths = df["video_width"].to_list()
# Delete redundant (non-related) properties
del df["video_height"]
del df["video_width"]
# del df["video_height"]
# del df["video_width"]
# Temporarily remove other relevant metadata
labels = df["labels"].to_list()
video_fps = df["fps"].to_list()
signs = df["sign"].to_list()
del df["labels"]
del df["fps"]
del df["split"]
del df["video_id"]
del df["label_name"]
del df["length"]
del df["sign"]
del df["path"]
del df["participant_id"]
del df["sequence_id"]
# Convert the strings into lists
@ -41,7 +41,7 @@ for column in df.columns:
# Perform the normalizations
df = normalize_hands_full(df)
df, invalid_row_indexes = normalize_body_full(df)
# df, invalid_row_indexes = normalize_body_full(df)
# Clear lists of items from deleted rows
# labels = [t for i, t in enumerate(labels) if i not in invalid_row_indexes]
@ -49,6 +49,6 @@ df, invalid_row_indexes = normalize_body_full(df)
# Return the metadata back to the dataset
df["labels"] = labels
df["fps"] = video_fps
df["sign"] = signs
df.to_csv(os.path.join(DATASET_PATH, "wlasl_train_norm.csv"), encoding="utf-8", index=False)
df.to_csv(os.path.join(DATASET_PATH, "spoter_train_norm.csv"), encoding="utf-8", index=False)

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "c20f7fd5",
"metadata": {},
"outputs": [],
@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "ada032d0",
"metadata": {},
"outputs": [],
@ -22,13 +22,12 @@
"import os\n",
"import os.path as op\n",
"import pandas as pd\n",
"import json\n",
"import base64"
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "05682e73",
"metadata": {},
"outputs": [],
@ -38,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "fede7684",
"metadata": {},
"outputs": [],
@ -48,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "ce531994",
"metadata": {},
"outputs": [],
@ -64,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "f4a2d672",
"metadata": {},
"outputs": [],
@ -87,17 +86,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "1d9db764",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f29f89e3ed0>"
"<torch._C.Generator at 0x7f010919d710>"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -119,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "71224139",
"metadata": {},
"outputs": [],
@ -155,7 +154,7 @@
"# checkpoint = torch.load(model.get_weights())\n",
"\n",
"## Set your path to checkoint here\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_6.pth\"\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"\n",
"model = SPOTER_EMBEDDINGS(\n",
@ -169,27 +168,28 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 24,
"id": "ba6b58f0",
"metadata": {},
"outputs": [],
"source": [
"SL_DATASET = 'wlasl' # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
"\n",
"if SL_DATASET == 'fingerspelling':\n",
" dataset_name = \"fingerspelling\"\n",
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
"elif SL_DATASET == 'wlasl':\n",
" dataset_name = \"wlasl\"\n",
" num_classes = 100\n",
" split_dataset_path = \"WLASL100_train.csv\"\n",
"else:\n",
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
" num_classes = 64\n",
" split_dataset_path = \"LSA64_{}.csv\"\n",
" \n",
" split_dataset_path = \"WLASL100_{}.csv\"\n",
"elif SL_DATASET == 'basic-signs':\n",
" dataset_name = \"basic-signs\"\n",
" split_dataset_path = \"basic-signs_{}.csv\"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 25,
"id": "5643a72c",
"metadata": {},
"outputs": [],
@ -228,7 +228,7 @@
"outputs": [],
"source": [
"dataloaders = {}\n",
"splits = ['train', 'val']\n",
"splits = ['train', 'val']\n",
"dfs = {}\n",
"for split in splits:\n",
" split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
@ -269,6 +269,8 @@
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
" k += 1\n",
" inputs = inputs.to(device)\n",
" \n",
"\n",
" masks = masks.to(device)\n",
" outputs = model(inputs, masks)\n",
" for n in range(outputs.shape[0]):\n",
@ -285,7 +287,7 @@
{
"data": {
"text/plain": [
"(810, 810)"
"(164, 164)"
]
},
"execution_count": 19,
@ -299,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
@ -311,6 +313,70 @@
" df['embeddings'] = embeddings_split[split]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0b9fb9c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
" ... \n",
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
"Name: embeddings, Length: 164, dtype: object\n",
"0 TOT-ZIENS\n",
"1 GOED\n",
"2 GOEDENACHT\n",
"3 NEE\n",
"4 SLECHT\n",
" ... \n",
"159 SORRY\n",
"160 GOEDEMORGEN\n",
"161 LINKS\n",
"162 TOT-ZIENS\n",
"163 GOED\n",
"Name: label_name, Length: 164, dtype: object\n",
"0 0\n",
"1 1\n",
"2 2\n",
"3 3\n",
"4 4\n",
" ..\n",
"159 7\n",
"160 5\n",
"161 13\n",
"162 0\n",
"163 1\n",
"Name: labels, Length: 164, dtype: int64\n"
]
}
],
"source": [
"print(dfs['train'][\"embeddings\"])\n",
"print(dfs['train'][\"label_name\"])\n",
"print(dfs['train'][\"labels\"])\n",
"\n",
"# only keep these columns\n",
"dfs['train'] = dfs['train'][['embeddings', 'label_name', 'labels']]\n",
"\n",
"# convert embeddings to string\n",
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
"\n",
"# save the dfs['train']\n",
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
]
},
{
"cell_type": "markdown",
"id": "2951638d",
@ -322,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"id": "7399b8ae",
"metadata": {},
"outputs": [
@ -331,16 +397,16 @@
"output_type": "stream",
"text": [
"Using centroids only\n",
"Top-1 accuracy: 5.19 %\n",
"Top-5 embeddings class match: 17.65 % (Picks any class in the 5 closest embeddings)\n",
"Top-1 accuracy: 80.00 %\n",
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
"\n",
"################################\n",
"\n",
"Using all embeddings\n",
"Top-1 accuracy: 5.31 %\n",
"5-nn accuracy: 5.56 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 15.43 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 15.56 % (Picks the 5 closest distinct classes)\n",
"Top-1 accuracy: 80.00 %\n",
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
"\n",
"################################\n",
"\n"
@ -375,13 +441,13 @@
" sorted_labels = labels[argsort]\n",
" if sorted_labels[0] == true_label:\n",
" top1 += 1\n",
" if use_centroids:\n",
" good_samples.append(df_val.loc[i, 'video_id'])\n",
" else:\n",
" good_samples.append((df_val.loc[i, 'video_id'],\n",
" df_train.loc[argsort[0], 'video_id'],\n",
" i,\n",
" argsort[0]))\n",
" # if use_centroids:\n",
" # good_samples.append(df_val.loc[i, 'video_id'])\n",
" # else:\n",
" # good_samples.append((df_val.loc[i, 'video_id'],\n",
" # df_train.loc[argsort[0], 'video_id'],\n",
" # i,\n",
" # argsort[0]))\n",
"\n",
"\n",
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,5 @@
from argparse import ArgumentParser
from preprocessing.create_wlasl_landmarks_dataset import parse_create_args, create
from preprocessing.create_fingerspelling_dataset import parse_create_args, create
from preprocessing.extract_mediapipe_landmarks import parse_extract_args, extract

View File

@ -0,0 +1,174 @@
import os
import os.path as op
import json
import shutil
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from utils import get_logger
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from normalization.blazepose_mapping import map_blazepose_df
BASE_DATA_FOLDER = 'data/'
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
mp_holistic = mp.solutions.holistic
pose_landmarks = mp_holistic.PoseLandmark
hand_landmarks = mp_holistic.HandLandmark
def get_landmarks_names():
'''
Returns landmark names for mediapipe holistic model
'''
pose_lmks = ','.join([f'{lmk.name.lower()}_x,{lmk.name.lower()}_y' for lmk in pose_landmarks])
left_hand_lmks = ','.join([f'left_hand_{lmk.name.lower()}_x,left_hand_{lmk.name.lower()}_y'
for lmk in hand_landmarks])
right_hand_lmks = ','.join([f'right_hand_{lmk.name.lower()}_x,right_hand_{lmk.name.lower()}_y'
for lmk in hand_landmarks])
lmks_names = f'{pose_lmks},{left_hand_lmks},{right_hand_lmks}'
return lmks_names
def convert_to_str(arr, precision=6):
if isinstance(arr, np.ndarray):
values = []
for val in arr:
if val == 0:
values.append('0')
else:
values.append(f'{val:.{precision}f}')
return f"[{','.join(values)}]"
else:
return str(arr)
def parse_create_args(parser):
parser.add_argument('--landmarks-dataset', '-lmks', required=True,
help='Path to folder with landmarks npy files. \
You need to run `extract_mediapipe_landmarks.py` script first')
parser.add_argument('--dataset-folder', '-df', default='data/wlasl',
help='Path to folder where original `WLASL_v0.3.json` and `id_to_label.json` are stored. \
Note that final CSV files will be saved in this folder too.')
parser.add_argument('--videos-folder', '-videos', default=None,
help='Path to folder with videos. If None, then no information of videos (fps, length, \
width and height) will be stored in final csv file')
parser.add_argument('--num-classes', '-nc', default=100, type=int, help='Number of classes to use in WLASL dataset')
parser.add_argument('--create-new-split', action='store_true')
parser.add_argument('--test-size', '-ts', default=0.25, type=float,
help='Test split percentage size. Only required if --create-new-split is set')
# python3 preprocessing.py --landmarks-dataset=data/landmarks -videos data/wlasl/videos
def create(args):
logger = get_logger(__name__)
landmarks_dataset = args.landmarks_dataset
videos_folder = args.videos_folder
dataset_folder = args.dataset_folder
num_classes = args.num_classes
test_size = args.test_size
os.makedirs(dataset_folder, exist_ok=True)
# shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/id_to_label.json'), dataset_folder)
# shutil.copy(os.path.join(BASE_DATA_FOLDER, 'wlasl/WLASL_v0.3.json'), dataset_folder)
# get files in landmarks_dataset folder
landmarks_files = os.listdir(landmarks_dataset)
video_data = []
for i, file in enumerate(tqdm(landmarks_files)):
# split by !
label = file.split('!')[0]
subset = file.split('!')[1].split('.')[0]
# remove npy and set mp4
video_id = file.replace('.npy', "")
video_dict = {'video_id': video_id,
'label_name': label,
'split': subset}
if videos_folder is not None:
cap = cv2.VideoCapture(op.join(videos_folder, f'{video_id}.mp4'))
if not cap.isOpened():
logger.warning(f'Video {video_id}.mp4 not found')
continue
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
fps = cap.get(cv2.CAP_PROP_FPS)
length = cap.get(cv2.CAP_PROP_FRAME_COUNT) / float(cap.get(cv2.CAP_PROP_FPS))
video_info = {'video_width': width,
'video_height': height,
'fps': fps,
'length': length}
video_dict.update(video_info)
video_data.append(video_dict)
df_video = pd.DataFrame(video_data)
video_ids = df_video['video_id'].unique()
lmks_data = []
lmks_names = get_landmarks_names().split(',')
# get labels from df_video
labels = df_video['label_name'].unique()
# map labels to ids
label_to_id = {label: i for i, label in enumerate(labels)}
# add label_id column to df_video
df_video['labels'] = df_video['label_name'].map(label_to_id)
# export to json file as id to label
id_to_label = {i: label for label, i in label_to_id.items()}
with open(op.join(dataset_folder, 'id_to_label.json'), 'w') as f:
json.dump(id_to_label, f, indent=4)
for video_id in video_ids:
lmk_fn = op.join(landmarks_dataset, f'{video_id}.npy')
if not op.exists(lmk_fn):
logger.warning(f'{lmk_fn} file not found. Skipping')
continue
lmk = np.load(lmk_fn).T
lmks_dict = {'video_id': video_id}
for lmk_, name in zip(lmk, lmks_names):
lmks_dict[name] = lmk_
lmks_data.append(lmks_dict)
df_lmks = pd.DataFrame(lmks_data)
df = pd.merge(df_video, df_lmks)
aux_columns = ['split', 'video_id', 'labels', 'label_name']
if videos_folder is not None:
aux_columns += ['video_width', 'video_height', 'fps', 'length']
df_aux = df[aux_columns]
df = map_blazepose_df(df)
df = pd.concat([df, df_aux], axis=1)
if args.create_new_split:
df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
print(f'Num classes: {num_classes}')
print(df_train['labels'].value_counts())
print(df_test['labels'].value_counts())
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
the --create-new-split flag'
for split, df_split in zip(['train', 'val'],
[df_train, df_test]):
fn_out = op.join(dataset_folder, f'{split}.csv')
(df_split.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))
else:
fn_out = op.join(dataset_folder, 'train.csv')
(df.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))

View File

@ -4,6 +4,8 @@ import pandas as pd
from tqdm.auto import tqdm
import json
from normalization.blazepose_mapping import map_blazepose_df
def create(train_landmark_files, train_csv, dataset_folder, test_size):
os.makedirs(dataset_folder, exist_ok=True)
@ -17,15 +19,15 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
mapping = {
'pose_0': 'nose',
'pose_1': 'leftEye',
'pose_2': 'rightEye',
'pose_3': 'leftEar',
'pose_4': 'rightEar',
'pose_5': 'leftShoulder',
'pose_6': 'rightShoulder',
'pose_7': 'leftElbow',
'pose_8': 'rightElbow',
'pose_9': 'leftWrist',
'pose_10': 'rightWrist',
'pose_4': 'rightEye',
'pose_7': 'leftEar',
'pose_8': 'rightEar',
'pose_11': 'leftShoulder',
'pose_12': 'rightShoulder',
'pose_13': 'leftElbow',
'pose_14': 'rightElbow',
'pose_15': 'leftWrist',
'pose_16': 'rightWrist',
'left_hand_0': 'wrist_left',
'left_hand_1': 'thumbCMC_left',
@ -77,7 +79,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
columns.append(f'{v}_X')
columns.append(f'{v}_Y')
for _, row in tqdm(train_df.head(6000).iterrows(), total=6000):
for _, row in tqdm(train_df.head(10000).iterrows(), total=10000):
path, participant_id, sequence_id, sign = row['path'], row['participant_id'], row['sequence_id'], row['sign']
parquet_file = os.path.join(train_landmark_files, str(participant_id), f"{sequence_id}.parquet")
@ -136,6 +138,7 @@ def create(train_landmark_files, train_csv, dataset_folder, test_size):
video_data.append(new_landmark_data)
video_data = pd.concat(video_data, axis=0, ignore_index=True)
video_data = map_blazepose_df(video_data, rename=False)
video_data.to_csv(os.path.join(dataset_folder, 'spoter.csv'), index=False)
train_landmark_files = 'data/train_landmark_files'

View File

@ -110,6 +110,7 @@ def create(args):
'length': length}
video_dict.update(video_info)
video_data.append(video_dict)
df_video = pd.DataFrame(video_data)
video_ids = df_video['video_id'].unique()
lmks_data = []
@ -126,9 +127,7 @@ def create(args):
lmks_data.append(lmks_dict)
df_lmks = pd.DataFrame(lmks_data)
print(df_lmks)
df = pd.merge(df_video, df_lmks)
print(df)
aux_columns = ['split', 'video_id', 'labels', 'label_name']
if videos_folder is not None:
aux_columns += ['video_width', 'video_height', 'fps', 'length']

View File

@ -37,6 +37,10 @@ class LandmarksResults:
self.num_landmarks_pose = num_landmarks_pose
self.num_landmarks_hand = num_landmarks_hand
@property
def empty(self):
return self.results.pose_landmarks is None or (self.results.left_hand_landmarks is None and self.results.right_hand_landmarks is None)
@property
def pose_landmarks(self):
if self.results.pose_landmarks is None:
@ -67,6 +71,10 @@ def get_landmarks(image_orig, holistic, debug=False):
# Convert the BGR image to RGB before processing.
image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
results = LandmarksResults(holistic.process(image))
if results.empty:
return None
if debug:
lmks_pose = []
for lmk in results.pose_landmarks:
@ -94,6 +102,7 @@ def get_landmarks(image_orig, holistic, debug=False):
len(lmks_right_hand) == 2 * LEN_LANDMARKS_HAND
), f"{len(lmks_right_hand)} != {2 * LEN_LANDMARKS_HAND}"
landmarks = []
for lmk in chain(
results.pose_landmarks,
results.left_hand_landmarks,
@ -128,10 +137,21 @@ def extract(args):
videos_folder = args.videos_folder
os.makedirs(landmarks_output, exist_ok=True)
for fn_video in tqdm(sorted(glob.glob(op.join(videos_folder, "*mp4")))):
# check if landmarks already exist
if op.exists(op.join(landmarks_output, op.basename(fn_video).split(".")[0] + ".npy")):
continue
cap = cv2.VideoCapture(fn_video)
ret, image_orig = cap.read()
height, width = image_orig.shape[:2]
landmarks_video = []
# make sure fps is 20 by determining the number of frames to be skipped
frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
frame_skip = (frame_rate // 10) - 1
with tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) as pbar:
with mp_holistic.Holistic(
static_image_mode=False,
@ -145,7 +165,11 @@ def extract(args):
print(e)
landmarks = get_landmarks(image_orig, holistic, debug=True)
ret, image_orig = cap.read()
landmarks_video.append(landmarks)
for _ in range(frame_skip):
ret, image_orig = cap.read()
pbar.update(1)
if landmarks:
landmarks_video.append(landmarks)
pbar.update(1)
landmarks_video = np.vstack(landmarks_video)
np.save(

View File

@ -8,7 +8,6 @@ dataset = "data/processed/spoter.csv"
# read the dataset
df = pd.read_csv(dataset)
df = map_blazepose_df(df)
with open("data/sign_to_prediction_index_map.json", "r") as f:
sign_to_prediction_index_max = json.load(f)
@ -17,6 +16,9 @@ with open("data/sign_to_prediction_index_map.json", "r") as f:
# filter df to make sure each sign has at least 4 samples
df = df[df["sign"].map(df["sign"].value_counts()) > 4]
# print number of unique signs
print("Number of unique signs: ", len(df["sign"].unique()))
# use the path column to split the dataset
paths = df["path"].unique()

View File

@ -1,7 +1,6 @@
pandas
bokeh==2.4.3
boto3>=1.9
clearml==1.6.4
ipywidgets==8.0.4
matplotlib==3.5.3
mediapipe==0.8.11
@ -9,6 +8,9 @@ notebook==6.5.2
opencv-python==4.6.0.66
plotly==5.11.0
scikit-learn==1.0.2
torch
torchvision
clearml==1.10.3
torch==2.0.0
torchvision==0.15.1
tqdm==4.54.1
optuna==3.1.1
onnx==1.14.0

View File

@ -15,7 +15,7 @@ from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
import copy
import numpy as np
from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
train_epoch_embedding_online, evaluate_embedding
@ -32,7 +32,7 @@ except ImportError:
pass
PROJECT_NAME = "spoter"
PROJECT_NAME = "SpoterEmbedding"
CLEARML = "clearml"
@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
# Construct the model
if not args.classification_model:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout
)
# if finetune, load the weights from the classification model
if args.finetune:
checkpoint = torch.load(args.checkpoint_path, map_location=device)
slrt_model = SPOTER_EMBEDDINGS(
features=checkpoint["config_args"].vector_length,
hidden_dim=checkpoint["config_args"].hidden_dim,
norm_emb=checkpoint["config_args"].normalize_embeddings,
dropout=checkpoint["config_args"].dropout,
)
else:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout,
)
model_type = 'embed'
if args.hard_triplet_mining == "None":
cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
@ -233,6 +246,9 @@ def train(args, tracker: Tracker):
val_accs.append(val_acc)
tracker.log_scalar_metric("acc", "val", epoch, val_acc)
create_embedding_scatter_plots(tracker, slrt_model, train_loader, val_loader, device, id_to_label, epoch,
top_model_name)
logger.info(f"Epoch time: {datetime.now() - start_time}")
logger.info("[" + str(epoch) + "] TRAIN loss: " + str(train_loss) + " acc: " + str(train_accs[-1]))
logger.info("[" + str(epoch) + "] VALIDATION acc: " + str(val_accs[-1]))

View File

@ -1,22 +1,24 @@
#!/bin/sh
python -m train \
python3 -m train \
--save_checkpoints_every 10 \
--experiment_name "augment_rotate_75_x8" \
--epochs 300 \
--experiment_name "Finetune Fingerspelling Signs" \
--epochs 1000 \
--optimizer "ADAM" \
--lr 0.001 \
--batch_size 16 \
--dataset_name "processed" \
--training_set_path "spoter_train.csv" \
--validation_set_path "spoter_test.csv" \
--lr 0.00001 \
--batch_size 8 \
--dataset_name "FingerSpelling" \
--training_set_path "train.csv" \
--validation_set_path "val.csv" \
--vector_length 32 \
--epoch_iters -1 \
--scheduler_factor 0 \
--hard_triplet_mining "in_batch" \
--filter_easy_triplets \
--triplet_loss_margin 1 \
--start_mining_hard 50 \
--triplet_loss_margin 4 \
--dropout 0.2 \
--augmentations_prob=0.75 \
--hard_mining_scheduler_triplets_threshold=0 \
--normalize_embeddings \
--num_classes 100 \
--tracker=clearml \
--dataset_loader=clearml \
--dataset_project="SpoterEmbedding" \
--finetune \
--checkpoint_path "checkpoints/checkpoint_embed_3835.pth"

View File

@ -81,4 +81,7 @@ def get_default_args():
help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
distance threshold of the batch sorter")
parser.add_argument("--finetune", action='store_true', default=False, help="Fintune the model")
parser.add_argument("--checkpoint_path", type=str, default="")
return parser

1636
visualize_data.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@ -238,6 +238,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
f"{kk}-dimensional vectors")
if m*n*k <= threshold:
print("Using minkowski_distance")
return minkowski_distance(x[:,np.newaxis,:],y[np.newaxis,:,:],p)
else:
result = np.empty((m,n),dtype=float) # FIXME: figure out the best dtype