New basic signs model

This commit is contained in:
Jerome Coudron
2023-05-07 21:00:52 +00:00
committed by Jelle De Geest
parent 06aa9206ac
commit 43887af670
111 changed files with 952 additions and 329 deletions

View File

@@ -369,35 +369,20 @@ public class SignPredictor : MonoBehaviour
yield return new WaitUntil(() => t.IsCompleted);
model = t.Result;
modelID = modelList.GetCurrentModelIndex();
predictor_embed = new NatMLSignPredictorEmbed(model);
asyncPredictor = predictor_embed.ToAsync();
// Creating a KeypointManager
keypointManagerEmbedding = new KeypointManagerEmbedding();
if (modelID == ModelIndex.FINGERSPELLING)
{
predictor = new NatMLSignPredictor(model);
asyncPredictor = predictor.ToAsync();
// Creating a KeypointManager
keypointManager = new KeypointManager(modelInfoFile);
StartCoroutine(SignRecognitionCoroutine());
StartCoroutine(MediapipeCoroutine());
}
else
{
predictor_embed = new NatMLSignPredictorEmbed(model);
asyncPredictor = predictor_embed.ToAsync();
// Creating a KeypointManager
keypointManagerEmbedding = new KeypointManagerEmbedding();
// read the embedding data
embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>($"{{\"dataList\":{modelInfoFileEmbedding}}}");
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutineEmbed());
StartCoroutine(MediapipeCoroutineEmbed());
}
// read the embedding data
embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>($"{{\"dataList\":{modelList.GetEmbeddings()}}}");
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutineEmbed());
StartCoroutine(MediapipeCoroutineEmbed());
}
/*
/// <summary>
/// Coroutine which executes the mediapipe pipeline
/// </summary>
@@ -423,6 +408,7 @@ public class SignPredictor : MonoBehaviour
keypointManager.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
}
}
*/
/// <summary>
/// Coroutine which executes the mediapipe pipeline
@@ -491,6 +477,7 @@ public class SignPredictor : MonoBehaviour
return distances;
}
/*
/// <summary>
/// Coroutine which calls the sign predictor model
/// </summary>
@@ -541,6 +528,7 @@ public class SignPredictor : MonoBehaviour
}
}
*/
/// <summary>
/// Coroutine which calls the sign predictor embedding model
@@ -551,6 +539,7 @@ public class SignPredictor : MonoBehaviour
while (true)
{
List<List<List<float>>> inputData = keypointManagerEmbedding.GetKeypoints();
if (inputData != null && asyncPredictor.readyForPrediction)
{
// Getting the size of the input data
@@ -574,48 +563,32 @@ public class SignPredictor : MonoBehaviour
yield return new WaitUntil(() => task.IsCompleted);
List<float> result = task.Result;
if (0 < result.Count)
if (result.Count > 0)
{
List<DistanceEmbedding> distances = GetDistances(result, 2);
var probs = new Dictionary<string, float>();
learnableProbabilities = new Dictionary<string, float>();
for (int j = 0; j < distances.Count; j++)
{
DistanceEmbedding distanceEmbedding = distances[j];
// check if already in dictionary
if (probs.ContainsKey(distanceEmbedding.embeddingData.label_name))
if (learnableProbabilities.ContainsKey(distanceEmbedding.embeddingData.label_name))
{
// if so, check if the distance is smaller
if (probs[distanceEmbedding.embeddingData.label_name] > distanceEmbedding.distance)
if (learnableProbabilities[distanceEmbedding.embeddingData.label_name] > distanceEmbedding.distance)
{
// if so, replace the distance
probs[distanceEmbedding.embeddingData.label_name] = distanceEmbedding.distance;
learnableProbabilities[distanceEmbedding.embeddingData.label_name] = distanceEmbedding.distance;
}
}
else
{
// if not, add the distance to the dictionary
probs.Add(distanceEmbedding.embeddingData.label_name, distanceEmbedding.distance);
learnableProbabilities.Add(distanceEmbedding.embeddingData.label_name, distanceEmbedding.distance);
}
}
// convert distances to probabilities, the closer to 1.5 the better the prediction
var newProbs = new Dictionary<string, float>();
float sum = 0.0f;
foreach (KeyValuePair<string, float> entry in probs)
{
float probability = 1 / (1 + Mathf.Exp(2 * (entry.Value - 1.85f)));
newProbs.Add(entry.Key, probability);
sum += probability;
}
learnableProbabilities = new Dictionary<string, float>();
foreach (var kv in newProbs)
learnableProbabilities.Add(kv.Key, kv.Value / sum);
//UnityEngine.Debug.Log($"{learnableProbabilities.Aggregate("", (t, e) => $"{t}{e.Key}={e.Value}, ")}");
foreach (Listener listener in listeners)
{
yield return listener.ProcessIncomingCall();