Demo day booth
This commit is contained in:
committed by
Jelle De Geest
parent
5b4a3ec4e7
commit
fcd8acad1e
@@ -11,11 +11,101 @@ using System.Threading.Tasks;
|
||||
using UnityEngine;
|
||||
using UnityEngine.UI;
|
||||
|
||||
|
||||
/// <summary>
/// A single reference embedding together with the label it was recorded for.
/// The field names are looked up by name when the embedding file is deserialized,
/// so they must stay exactly as they are spelled in the JSON data.
/// </summary>
[System.Serializable]
public class EmbeddingData
{
    /// <summary>
    /// The embedding vector of this reference sample
    /// </summary>
    public float[] embeddings;

    /// <summary>
    /// Name of the label; used as the key of the probability dictionary by the sign predictor
    /// </summary>
    public string label_name;

    /// <summary>
    /// Presumably the numeric id of the label - TODO confirm against the embedding file
    /// </summary>
    public int labels;
}
|
||||
|
||||
/// <summary>
/// Wrapper object around the list of reference embeddings.
/// The embedding file is wrapped in an object with a "dataList" member before it is parsed,
/// which is why the field below has to keep this exact name.
/// </summary>
[System.Serializable]
public class EmbeddingDataList
{
    /// <summary>
    /// All reference embeddings which were read from the embedding file
    /// </summary>
    public List<EmbeddingData> dataList;
}
|
||||
|
||||
/// <summary>
/// Pairs a reference embedding with the distance that was computed between it and a predicted embedding
/// </summary>
public class DistanceEmbedding
{
    /// <summary>
    /// Distance between the predicted embedding and <see cref="embeddingData"/>
    /// </summary>
    public float distance;

    /// <summary>
    /// The reference embedding the distance was measured against
    /// </summary>
    public EmbeddingData embeddingData;

    /// <summary>
    /// Creation of a DistanceEmbedding instance
    /// </summary>
    /// <param name="distance">The computed distance</param>
    /// <param name="embeddingData">The reference embedding belonging to that distance</param>
    public DistanceEmbedding(float distance, EmbeddingData embeddingData)
    {
        this.embeddingData = embeddingData;
        this.distance = distance;
    }
}
|
||||
|
||||
/// <summary>
/// Orders <see cref="DistanceEmbedding"/> instances by ascending distance
/// </summary>
public class DistanceComparer : IComparer<DistanceEmbedding>
{
    /// <summary>
    /// Compares two entries by their distance
    /// </summary>
    /// <param name="x">First entry, may be null</param>
    /// <param name="y">Second entry, may be null</param>
    /// <returns>
    /// A negative value when x sorts before y, zero when they are equal and a positive value otherwise.
    /// A null entry sorts before any non-null entry, as the IComparer contract prescribes.
    /// </returns>
    public int Compare(DistanceEmbedding x, DistanceEmbedding y)
    {
        // Same instance (or both null): equal by definition
        if (object.ReferenceEquals(x, y))
        {
            return 0;
        }

        // Comparers must not throw on null, null is considered the smallest value
        if (x == null)
        {
            return -1;
        }
        if (y == null)
        {
            return 1;
        }

        return x.distance.CompareTo(y.distance);
    }
}
|
||||
|
||||
/// <summary>
|
||||
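/// MonoBehaviour which feeds webcam frames into the Mediapipe graph and predicts signs from the resulting landmarks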
|
||||
/// </summary>
|
||||
public class SignPredictor : MonoBehaviour
|
||||
{
|
||||
/// <summary>
/// Predictor class which runs an MLEdgeModel and returns its first output as a flat list of floats
/// (the caller compares this output against the reference embeddings)
/// </summary>
public class NatMLSignPredictorEmbed : IMLPredictor<List<float>>
{
    /// <summary>
    /// The MLEdgeModel used for predictions
    /// </summary>
    private readonly MLEdgeModel edgeModel;

    /// <summary>
    /// The type used to create features which are input for the model
    /// </summary>
    private readonly MLFeatureType featureType;

    /// <summary>
    /// Creation of a NatMLSignPredictorEmbed instance
    /// </summary>
    /// <param name="edgeModel">The model to predict with; its first input determines the feature type</param>
    /// <exception cref="System.ArgumentNullException">When no model is given</exception>
    public NatMLSignPredictorEmbed(MLEdgeModel edgeModel)
    {
        if (edgeModel == null)
        {
            throw new System.ArgumentNullException(nameof(edgeModel));
        }

        this.edgeModel = edgeModel;
        featureType = edgeModel.inputs[0];
    }

    /// <summary>
    /// Runs the MLEdgeModel on the first input feature
    /// </summary>
    /// <param name="inputs">The input features, only the first one is used</param>
    /// <returns>The flattened first output of the model, or null when the model produced no output</returns>
    /// <exception cref="System.ArgumentException">When no input feature is given</exception>
    public List<float> Predict(params MLFeature[] inputs)
    {
        if (inputs == null || inputs.Length == 0)
        {
            throw new System.ArgumentException("At least one input feature is required", nameof(inputs));
        }

        MLEdgeFeature edgeFeature = ((IMLEdgeFeature)inputs[0]).Create(featureType);
        try
        {
            MLFeatureCollection<MLEdgeFeature> result = edgeModel.Predict(edgeFeature);
            try
            {
                if (result.Count <= 0)
                {
                    return null;
                }

                return new MLArrayFeature<float>(result[0]).Flatten().ToArray().ToList();
            }
            finally
            {
                // Released even when reading the output throws
                result.Dispose();
            }
        }
        finally
        {
            // Released even when the prediction itself throws
            edgeFeature.Dispose();
        }
    }

    /// <summary>
    /// Disposing the MLEdgeModel
    /// </summary>
    public void Dispose()
    {
        edgeModel.Dispose();
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Predictor class which is used to predict the sign using an MLEdgeModel
|
||||
/// </summary>
|
||||
@@ -79,8 +169,11 @@ public class SignPredictor : MonoBehaviour
|
||||
/// <summary>
|
||||
/// Predictor which is used to create the asyncPredictor (should not be used if asyncPredictor exists)
|
||||
/// </summary>
|
||||
private NatMLSignPredictorEmbed predictor_embed;
|
||||
|
||||
/// <summary>
/// Predictor for the non-embedding model, created when the current model is the fingerspelling model;
/// like predictor_embed it is only used to create the asyncPredictor
/// </summary>
private NatMLSignPredictor predictor;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// The asynchronous predictor which is used to predict the sign using an MLEdgemodel
|
||||
/// </summary>
|
||||
@@ -105,6 +198,7 @@ public class SignPredictor : MonoBehaviour
|
||||
/// Reference to the model info file
|
||||
/// </summary>
|
||||
public TextAsset modelInfoFile;
|
||||
/// <summary>
/// Reference to the file with the reference embeddings; during setup it is wrapped in a
/// {"dataList": ...} object and parsed into an EmbeddingDataList (presumably a JSON array - TODO confirm)
/// </summary>
public TextAsset modelInfoFileEmbedding;
|
||||
|
||||
/// <summary>
|
||||
/// Config file to set up the graph
|
||||
@@ -178,6 +272,11 @@ public class SignPredictor : MonoBehaviour
|
||||
/// </summary>
|
||||
private KeypointManager keypointManager;
|
||||
|
||||
/// <summary>
|
||||
/// A keypointmanager which does normalization stuff, keeps track of the landmarks (for embedding model)
|
||||
/// </summary>
|
||||
private KeypointManagerEmbedding keypointManagerEmbedding;
|
||||
|
||||
/// <summary>
|
||||
/// Width of the webcam
|
||||
/// </summary>
|
||||
@@ -198,6 +297,10 @@ public class SignPredictor : MonoBehaviour
|
||||
/// </summary>
|
||||
private static bool resourceManagerIsInitialized = false;
|
||||
|
||||
/// <summary>
/// The reference embeddings parsed from modelInfoFileEmbedding; GetDistances measures the model output against these
/// </summary>
private EmbeddingDataList embeddingDataList;
|
||||
|
||||
/// <summary>
/// Index of the current model in the modelList; decides whether the fingerspelling or the embedding pipeline is started
/// </summary>
private ModelIndex modelID;
|
||||
|
||||
/// <summary>
|
||||
/// Google Mediapipe setup & run
|
||||
/// </summary>
|
||||
@@ -258,9 +361,6 @@ public class SignPredictor : MonoBehaviour
|
||||
graph.StartRun().AssertOk();
|
||||
stopwatch.Start();
|
||||
|
||||
// Creating a KeypointManager
|
||||
keypointManager = new KeypointManager(modelInfoFile);
|
||||
|
||||
// Check if a model is ready to load
|
||||
yield return new WaitUntil(() => modelList.HasValidModel());
|
||||
|
||||
@@ -268,12 +368,34 @@ public class SignPredictor : MonoBehaviour
|
||||
Task<MLEdgeModel> t = Task.Run(() => MLEdgeModel.Create(modelList.GetCurrentModel()));
|
||||
yield return new WaitUntil(() => t.IsCompleted);
|
||||
model = t.Result;
|
||||
predictor = new NatMLSignPredictor(model);
|
||||
asyncPredictor = predictor.ToAsync();
|
||||
|
||||
// Start the Coroutine
|
||||
StartCoroutine(SignRecognitionCoroutine());
|
||||
StartCoroutine(MediapipeCoroutine());
|
||||
modelID = modelList.GetCurrentModelIndex();
|
||||
|
||||
if (modelID == ModelIndex.FINGERSPELLING)
|
||||
{
|
||||
predictor = new NatMLSignPredictor(model);
|
||||
asyncPredictor = predictor.ToAsync();
|
||||
// Creating a KeypointManager
|
||||
keypointManager = new KeypointManager(modelInfoFile);
|
||||
|
||||
StartCoroutine(SignRecognitionCoroutine());
|
||||
StartCoroutine(MediapipeCoroutine());
|
||||
}
|
||||
else
|
||||
{
|
||||
predictor_embed = new NatMLSignPredictorEmbed(model);
|
||||
asyncPredictor = predictor_embed.ToAsync();
|
||||
// Creating a KeypointManager
|
||||
keypointManagerEmbedding = new KeypointManagerEmbedding();
|
||||
|
||||
// read the embedding data
|
||||
embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>($"{{\"dataList\":{modelInfoFileEmbedding}}}");
|
||||
// Start the Coroutine
|
||||
StartCoroutine(SignRecognitionCoroutineEmbed());
|
||||
StartCoroutine(MediapipeCoroutineEmbed());
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -302,6 +424,73 @@ public class SignPredictor : MonoBehaviour
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
/// Coroutine which executes the mediapipe pipeline for the embedding model:
/// every iteration one webcam frame is sent to the graph and the landmarks which
/// are available afterwards are handed to the keypointManagerEmbedding
/// </summary>
/// <returns></returns>
private IEnumerator MediapipeCoroutineEmbed()
{
    // Number of TimeSpan ticks (100 ns each) in one microsecond
    const long ticksPerMicrosecond = System.TimeSpan.TicksPerMillisecond / 1000;

    while (true)
    {
        // Copy the current webcam image into the texture which backs the image frame
        inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
        var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData<byte>());

        // Elapsed.Ticks is always expressed in 100 ns units, whereas ElapsedTicks counts in
        // Stopwatch.Frequency units and only matches on timers that happen to run at 10 MHz
        var currentTimestamp = stopwatch.Elapsed.Ticks / ticksPerMicrosecond;
        graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();
        yield return new WaitForEndOfFrame();

        NormalizedLandmarkList _poseLandmarks = null;
        NormalizedLandmarkList _leftHandLandmarks = null;
        NormalizedLandmarkList _rightHandLandmarks = null;

        // The predicates always return true, so the result of TryGetNext is not waited for;
        // a landmark list presumably stays null when its stream had no packet - TODO confirm
        yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks); return true; });
        yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks); return true; });
        yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks); return true; });

        keypointManagerEmbedding.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
    }
}
|
||||
|
||||
|
||||
/// <summary>
/// Computes the Minkowski distance of order p between two vectors
/// </summary>
/// <param name="x">First vector; its length determines how many components are compared</param>
/// <param name="y">Second vector; must hold at least as many components as x, surplus components are ignored</param>
/// <param name="p">Order of the distance (1 = Manhattan, 2 = Euclidean), must be at least 1</param>
/// <returns>The p-th root of the sum of the p-th powers of the absolute component differences</returns>
/// <exception cref="System.ArgumentNullException">When one of the vectors is null</exception>
/// <exception cref="System.ArgumentOutOfRangeException">When p is smaller than 1</exception>
/// <exception cref="System.ArgumentException">When y has fewer components than x</exception>
private float MinkowskiDistance(List<float> x, float[] y, int p)
{
    if (x == null)
    {
        throw new System.ArgumentNullException(nameof(x));
    }
    if (y == null)
    {
        throw new System.ArgumentNullException(nameof(y));
    }
    if (p < 1)
    {
        // p == 0 would divide by zero in the final root, negative orders are not a distance
        throw new System.ArgumentOutOfRangeException(nameof(p), p, "The order of the distance must be at least 1");
    }

    int dimensions = x.Count;
    if (y.Length < dimensions)
    {
        throw new System.ArgumentException($"Expected at least {dimensions} components but got {y.Length}", nameof(y));
    }

    float sum = 0;
    for (int i = 0; i < dimensions; i++)
    {
        sum += Mathf.Pow(Mathf.Abs(x[i] - y[i]), p);
    }

    return Mathf.Pow(sum, 1.0f / p);
}
|
||||
|
||||
/// <summary>
/// Measures the distance between a predicted embedding and every reference embedding
/// </summary>
/// <param name="embedding">The embedding which was predicted by the model</param>
/// <param name="p">Order of the Minkowski distance, 2 (Euclidean) by default</param>
/// <returns>
/// One entry per reference embedding, sorted by ascending distance;
/// an empty list when no reference embeddings are loaded
/// </returns>
private List<DistanceEmbedding> GetDistances(List<float> embedding, int p = 2)
{
    // The list is null when the embedding file could not be parsed
    List<EmbeddingData> references = embeddingDataList?.dataList;
    if (references == null)
    {
        return new List<DistanceEmbedding>();
    }

    List<DistanceEmbedding> distances = new List<DistanceEmbedding>(references.Count);
    foreach (EmbeddingData data in references)
    {
        float distance = MinkowskiDistance(embedding, data.embeddings, p);
        distances.Add(new DistanceEmbedding(distance, data));
    }

    // Sorting once at the end avoids shifting the list on every single insertion
    distances.Sort(new DistanceComparer());

    return distances;
}
|
||||
|
||||
/// <summary>
|
||||
/// Coroutine which calls the sign predictor model
|
||||
/// </summary>
|
||||
@@ -345,7 +534,6 @@ public class SignPredictor : MonoBehaviour
|
||||
{
|
||||
learnableProbabilities.Add(signs[j].ToUpper(), result[j]);
|
||||
}
|
||||
//Debug.Log($"prob = [{learnableProbabilities.Aggregate(" ", (t, kv) => $"{t}{kv.Key}:{kv.Value} ")}]");
|
||||
foreach (Listener listener in listeners)
|
||||
{
|
||||
yield return listener.ProcessIncomingCall();
|
||||
@@ -363,6 +551,93 @@ public class SignPredictor : MonoBehaviour
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
/// Coroutine which calls the sign predictor embedding model.
/// Every iteration the collected keypoints are turned into an embedding, the embedding is compared
/// against the reference embeddings, the smallest distance per label is converted into a normalized
/// score which is stored in learnableProbabilities, and finally the listeners are notified.
/// </summary>
/// <returns></returns>
private IEnumerator SignRecognitionCoroutineEmbed()
{
    // Parameters of the logistic curve which maps a distance onto a score:
    // a distance equal to the midpoint scores 0.5, smaller distances score higher
    const float sigmoidSteepness = 2f;
    const float sigmoidMidpoint = 1.85f;

    while (true)
    {
        List<List<List<float>>> inputData = keypointManagerEmbedding.GetKeypoints();

        // Nothing to predict yet (no frames collected or the predictor is still busy)
        if (inputData == null || inputData.Count == 0 || !asyncPredictor.readyForPrediction)
        {
            yield return null;
            continue;
        }

        // Getting the size of the input data
        int framecount = inputData.Count;
        int keypointsPerFrame = inputData[0].Count;

        // Creating ArrayFeature, every keypoint contributes 2 values
        int[] shape = { 1, framecount, keypointsPerFrame, 2 };
        float[] input = new float[framecount * keypointsPerFrame * 2];

        int i = 0;
        foreach (List<List<float>> frame in inputData)
        {
            foreach (List<float> keypoint in frame)
            {
                foreach (float coordinate in keypoint)
                {
                    input[i++] = coordinate;
                }
            }
        }

        MLArrayFeature<float> feature = new MLArrayFeature<float>(input, shape);

        // Predicting
        Task<List<float>> task = Task.Run(async () => await asyncPredictor.Predict(feature));
        yield return new WaitUntil(() => task.IsCompleted);

        // Reading Result of a failed task would rethrow and stop the coroutine for good
        if (task.IsFaulted || task.IsCanceled)
        {
            if (task.Exception != null)
            {
                UnityEngine.Debug.LogException(task.Exception);
            }
            yield return null;
            continue;
        }

        // The predictor returns null when the model produced no output
        List<float> result = task.Result;
        if (result != null && 0 < result.Count)
        {
            List<DistanceEmbedding> distances = GetDistances(result, 2);

            // Keep only the smallest distance for every label
            var closestPerLabel = new Dictionary<string, float>();
            foreach (DistanceEmbedding candidate in distances)
            {
                string label = candidate.embeddingData.label_name;
                float best;
                if (!closestPerLabel.TryGetValue(label, out best) || candidate.distance < best)
                {
                    closestPerLabel[label] = candidate.distance;
                }
            }

            // Convert the distances to scores with the logistic curve
            var scores = new Dictionary<string, float>(closestPerLabel.Count);
            float sum = 0.0f;
            foreach (KeyValuePair<string, float> entry in closestPerLabel)
            {
                float score = 1 / (1 + Mathf.Exp(sigmoidSteepness * (entry.Value - sigmoidMidpoint)));
                scores.Add(entry.Key, score);
                sum += score;
            }

            // Normalize so the scores add up to 1; the dictionary is only published
            // once it is complete so readers never see a partially filled one
            var normalized = new Dictionary<string, float>(scores.Count);
            foreach (KeyValuePair<string, float> kv in scores)
            {
                // All scores can underflow to 0 for very large distances, avoid 0 / 0
                normalized.Add(kv.Key, sum > 0.0f ? kv.Value / sum : 0.0f);
            }
            learnableProbabilities = normalized;

            foreach (Listener listener in listeners)
            {
                yield return listener.ProcessIncomingCall();
            }
        }

        yield return null;
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Proper destruction of the Mediapipe graph
|
||||
/// </summary>
|
||||
|
||||
Reference in New Issue
Block a user