Implemented Embedding Signpredictor

This commit is contained in:
2023-04-19 13:14:06 +02:00
parent 5b4a3ec4e7
commit db96a700e8
15 changed files with 617 additions and 13 deletions

View File

@@ -10,12 +10,103 @@ using System.Linq;
using System.Threading.Tasks;
using UnityEngine;
using UnityEngine.UI;
using System.IO;
[System.Serializable]
public class EmbeddingData
{
public float[] embeddings;
public string label_name;
public int labels;
}
[System.Serializable]
public class EmbeddingDataList
{
public List<EmbeddingData> dataList;
}
public class DistanceEmbedding
{
public float distance;
public EmbeddingData embeddingData;
public DistanceEmbedding(float distance, EmbeddingData embeddingData)
{
this.distance = distance;
this.embeddingData = embeddingData;
}
}
public class DistanceComparer : IComparer<DistanceEmbedding>
{
public int Compare(DistanceEmbedding x, DistanceEmbedding y)
{
return x.distance.CompareTo(y.distance);
}
}
/// <summary>
///
/// </summary>
public class SignPredictor : MonoBehaviour
{
/// <summary>
/// Predictor class which is used to predict the sign using an MLEdgeModel
/// </summary>
public class NatMLSignPredictorEmbed : IMLPredictor<List<float>>
{
/// <summary>
/// The MLEdgeModel used for predictions
/// </summary>
private readonly MLEdgeModel edgeModel;
/// <summary>
/// The type used to create features which are input for the model
/// </summary>
private MLFeatureType featureType;
/// <summary>
/// Creation of a NatMLSignPredictor instance
/// </summary>
/// <param name="edgeModel"></param>
public NatMLSignPredictorEmbed(MLEdgeModel edgeModel)
{
this.edgeModel = edgeModel;
featureType = edgeModel.inputs[0];
}
/// <summary>
/// Predicts the sign using the MLEdgeModel
/// </summary>
/// <param name="inputs"></param>
/// <returns></returns>
public List<float> Predict(params MLFeature[] inputs)
{
List<float> predictions = null;
IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
MLFeatureCollection<MLEdgeFeature> result = edgeModel.Predict(edgeFeature);
if (0 < result.Count)
{
predictions = new MLArrayFeature<float>(result[0]).Flatten().ToArray().ToList();
}
edgeFeature.Dispose();
result.Dispose();
return predictions;
}
/// <summary>
/// Disposing the MLEdgeModel
/// </summary>
public void Dispose()
{
edgeModel.Dispose();
}
}
/// <summary>
/// Predictor class which is used to predict the sign using an MLEdgeModel
/// </summary>
@@ -79,8 +170,11 @@ public class SignPredictor : MonoBehaviour
/// <summary>
/// Predictor which is used to create the asyncPredictor (should not be used if asyncPredictor exists)
/// </summary>
private NatMLSignPredictorEmbed predictor_embed;
private NatMLSignPredictor predictor;
/// <summary>
/// The asynchronous predictor which is used to predict the sign using an MLEdgemodel
/// </summary>
@@ -178,6 +272,11 @@ public class SignPredictor : MonoBehaviour
/// </summary>
private KeypointManager keypointManager;
/// <summary>
/// A keypointmanager which does normalization stuff, keeps track of the landmarks (for embedding model)
/// </summary>
private KeypointManagerEmbedding keypointManagerEmbedding;
/// <summary>
/// Width of th webcam
/// </summary>
@@ -198,6 +297,10 @@ public class SignPredictor : MonoBehaviour
/// </summary>
private static bool resourceManagerIsInitialized = false;
private EmbeddingDataList embeddingDataList;
private ModelIndex modelID;
/// <summary>
/// Google Mediapipe setup & run
/// </summary>
@@ -258,9 +361,6 @@ public class SignPredictor : MonoBehaviour
graph.StartRun().AssertOk();
stopwatch.Start();
// Creating a KeypointManager
keypointManager = new KeypointManager(modelInfoFile);
// Check if a model is ready to load
yield return new WaitUntil(() => modelList.HasValidModel());
@@ -268,12 +368,47 @@ public class SignPredictor : MonoBehaviour
Task<MLEdgeModel> t = Task.Run(() => MLEdgeModel.Create(modelList.GetCurrentModel()));
yield return new WaitUntil(() => t.IsCompleted);
model = t.Result;
predictor = new NatMLSignPredictor(model);
asyncPredictor = predictor.ToAsync();
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutine());
StartCoroutine(MediapipeCoroutine());
modelID = modelList.GetCurrentModelIndex();
if (modelID == ModelIndex.FINGERSPELLING)
{
predictor = new NatMLSignPredictor(model);
asyncPredictor = predictor.ToAsync();
// Creating a KeypointManager
keypointManager = new KeypointManager(modelInfoFile);
StartCoroutine(SignRecognitionCoroutine());
StartCoroutine(MediapipeCoroutine());
}
else
{
predictor_embed = new NatMLSignPredictorEmbed(model);
asyncPredictor = predictor_embed.ToAsync();
// Creating a KeypointManager
keypointManagerEmbedding = new KeypointManagerEmbedding();
// read the embedding data
string filePath = Path.Combine(Application.dataPath, "Common/Models/BasicSigns", "embeddings.json");
if (File.Exists(filePath))
{
string jsonData = File.ReadAllText(filePath);
UnityEngine.Debug.Log(jsonData);
embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>("{\"dataList\":" + jsonData + "}");
}
else
{
UnityEngine.Debug.LogError("File not found: " + filePath);
}
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutineEmbed());
StartCoroutine(MediapipeCoroutineEmbed());
}
}
/// <summary>
@@ -302,6 +437,73 @@ public class SignPredictor : MonoBehaviour
}
}
/// <summary>
/// Coroutine which executes the mediapipe pipeline
/// </summary>
/// <returns></returns>
private IEnumerator MediapipeCoroutineEmbed()
{
while (true)
{
inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData<byte>());
var currentTimestamp = stopwatch.ElapsedTicks / (System.TimeSpan.TicksPerMillisecond / 1000);
graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();
yield return new WaitForEndOfFrame();
NormalizedLandmarkList _poseLandmarks = null;
NormalizedLandmarkList _leftHandLandmarks = null;
NormalizedLandmarkList _rightHandLandmarks = null;
yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks); return true; });
yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks); return true; });
yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks); return true; });
keypointManagerEmbedding.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
}
}
private float MinkowskiDistance(List<float> x, float[] y, int p)
{
int dimensions = x.Count;
float sum = 0;
for (int i = 0; i < dimensions; i++)
{
sum += Mathf.Pow(Mathf.Abs(x[i] - y[i]), p);
}
return Mathf.Pow(sum, 1.0f / p);
}
private List<DistanceEmbedding> GetDistances(List<float> embedding, int p = 2)
{
List<DistanceEmbedding> distances = new List<DistanceEmbedding>();
DistanceComparer comparer = new DistanceComparer();
foreach (EmbeddingData data in embeddingDataList.dataList)
{
float distance = MinkowskiDistance(embedding, data.embeddings, p);
DistanceEmbedding newDistanceEmbedding = new DistanceEmbedding(distance, data);
// Find the appropriate index to insert the new item to maintain the sorted order
int index = distances.BinarySearch(newDistanceEmbedding, comparer);
// If the index is negative, it represents the bitwise complement of the nearest larger element
if (index < 0)
{
index = ~index;
}
// Insert the new item at the appropriate position
distances.Insert(index, newDistanceEmbedding);
}
return distances;
}
/// <summary>
/// Coroutine which calls the sign predictor model
/// </summary>
@@ -345,7 +547,6 @@ public class SignPredictor : MonoBehaviour
{
learnableProbabilities.Add(signs[j].ToUpper(), result[j]);
}
//Debug.Log($"prob = [{learnableProbabilities.Aggregate(" ", (t, kv) => $"{t}{kv.Key}:{kv.Value} ")}]");
foreach (Listener listener in listeners)
{
yield return listener.ProcessIncomingCall();
@@ -363,6 +564,88 @@ public class SignPredictor : MonoBehaviour
}
/// <summary>
/// Coroutine which calls the sign predictor embedding model
/// </summary>
/// <returns></returns>
private IEnumerator SignRecognitionCoroutineEmbed()
{
while (true)
{
List<List<List<float>>> inputData = keypointManagerEmbedding.GetKeypoints();
if (inputData != null && asyncPredictor.readyForPrediction)
{
// Getting the size of the input data
int framecount = inputData.Count;
int keypointsPerFrame = inputData[0].Count;
// Creating ArrayFeature
int[] shape = { 1, framecount, keypointsPerFrame, 2 };
float[] input = new float[framecount * keypointsPerFrame * 2];
int i = 0;
inputData.ForEach((e) => e.ForEach((f) => f.ForEach((k) => input[i++] = k)));
MLArrayFeature<float> feature = new MLArrayFeature<float>(input, shape);
// Predicting
Task<List<float>> task = Task.Run(async () => await asyncPredictor.Predict(feature));
yield return new WaitUntil(() => task.IsCompleted);
List<float> result = task.Result;
if (0 < result.Count)
{
List<DistanceEmbedding> distances = GetDistances(result, 2);
learnableProbabilities = new Dictionary<string, float>();
for (int j = 0; j < distances.Count; j++)
{
DistanceEmbedding distanceEmbedding = distances[j];
// check if already in dictionary
if (learnableProbabilities.ContainsKey(distanceEmbedding.embeddingData.label_name))
{
// if so, check if the distance is smaller
if (learnableProbabilities[distanceEmbedding.embeddingData.label_name] > distanceEmbedding.distance)
{
// if so, replace the distance
learnableProbabilities[distanceEmbedding.embeddingData.label_name] = distanceEmbedding.distance;
}
}
else
{
// if not, add the distance to the dictionary
learnableProbabilities.Add(distanceEmbedding.embeddingData.label_name, distanceEmbedding.distance);
}
}
Dictionary<string, float> newLearnableProbabilities = new Dictionary<string, float>();
// convert distances to probabilities, the closer to 1.5 the better the prediction
foreach (KeyValuePair<string, float> entry in learnableProbabilities)
{
float probability = 1 / (1 + Mathf.Pow(2.71828f, (entry.Value - 1.6f) * 2));
newLearnableProbabilities.Add(entry.Key, probability);
}
learnableProbabilities = newLearnableProbabilities;
foreach (Listener listener in listeners)
{
yield return listener.ProcessIncomingCall();
}
}
}
yield return null;
}
}
/// <summary>
/// Propper destruction on the Mediapipegraph
/// </summary>