Merge remote-tracking branch 'origin/SpoterEmbedding' into demo-day

This commit is contained in:
Dries Van Schuylenbergh 2023-04-19 16:05:22 +02:00
commit db1a72fadd
15 changed files with 617 additions and 13 deletions

View File

@ -4,8 +4,9 @@ using UnityEngine;
/// <summary>
/// This enum is used to identify each of the SignLanguage models
/// </summary>
public enum ModelIndex
public enum ModelIndex
{
NONE,
FINGERSPELLING
FINGERSPELLING,
BASICSIGNS
}

View File

@ -74,4 +74,9 @@ public class ModelList : ScriptableObject
{
currentModelIndex = models.FindIndex((m) => m.index == index);
}
/// <summary>
/// Get the ModelIndex of the currently selected model
/// </summary>
/// <returns>The index of the current model, or ModelIndex.NONE when no valid model is selected</returns>
public ModelIndex GetCurrentModelIndex()
{
    // currentModelIndex is assigned from FindIndex, which yields -1 when no model matched,
    // so it has to be validated before it is used as a position in the list.
    if (models == null || currentModelIndex < 0 || currentModelIndex >= models.Count)
    {
        return ModelIndex.NONE;
    }
    return models[currentModelIndex].index;
}
}

View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: cadb927ad0f664463b8f5fef7146c561
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,10 @@
fileFormatVersion: 2
guid: 17fb70e1c284e44da8083b36bb6afcb8
ScriptedImporter:
internalIDToNameTable: []
externalObjects: {}
serializedVersion: 2
userData:
assetBundleName:
assetBundleVariant:
script: {fileID: 11500000, guid: 3e882272056fc4ddfa14de161aaba2ba, type: 3}

Binary file not shown.

View File

@ -0,0 +1,10 @@
fileFormatVersion: 2
guid: fa63c40c78ba548468cad97b15cdc6c9
ScriptedImporter:
internalIDToNameTable: []
externalObjects: {}
serializedVersion: 2
userData:
assetBundleName:
assetBundleVariant:
script: {fileID: 11500000, guid: 8264490bef67c46f2982e6dd3f5e46cd, type: 3}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 4e303164823194bc4be87f4c9550cfd0
TextScriptImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,263 @@
using Mediapipe;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
/// <summary>
/// Collects Mediapipe pose and hand landmarks, normalizes them and keeps a sliding window
/// of the most recent frames, which serves as input for the embedding model
/// </summary>
public class KeypointManagerEmbedding
{
    /// <summary>
    /// Number of pose landmarks read per frame (the neck is appended after these)
    /// </summary>
    private const int PoseLandmarkCount = 33;

    /// <summary>
    /// Number of landmarks read per hand
    /// </summary>
    private const int HandLandmarkCount = 21;

    /// <summary>
    /// Number of frames that have to be buffered before keypoints are handed out
    /// </summary>
    private const int BufferSize = 10;

    private const int LeftShoulderIndex = 11;
    private const int RightShoulderIndex = 12;

    /// <summary>
    /// The neck is not a Mediapipe landmark, it is appended behind the pose landmarks by CalculateNeck
    /// </summary>
    private const int NeckIndex = PoseLandmarkCount;

    private const int NoseIndex = 0;
    private const int LeftEyeIndex = 2;

    /// <summary>
    /// Pose landmarks (and their order) which are passed on to the model
    /// </summary>
    private static readonly int[] PoseIndices = { 0, 33, 5, 2, 8, 7, 12, 11, 14, 13, 16, 15 };

    /// <summary>
    /// Hand landmarks (and their order) which are passed on to the model
    /// </summary>
    private static readonly int[] HandIndices = { 0, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13, 20, 19, 18, 17, 4, 3, 2, 1 };

    /// <summary>
    /// Sliding window of frames; every frame is a list of [x, y] keypoints
    /// </summary>
    private readonly List<List<List<float>>> keypointsBuffer;

    /// <summary>
    /// Creation of a KeypointManagerEmbedding with an empty buffer
    /// </summary>
    public KeypointManagerEmbedding()
    {
        keypointsBuffer = new List<List<List<float>>>();
    }

    /// <summary>
    /// Normalizes the hand keypoints (in place) to a square bounding box around the detected hand
    /// </summary>
    /// <param name="handX">X coordinates of the hand, 0 marks a missing value</param>
    /// <param name="handY">Y coordinates of the hand, 0 marks a missing value</param>
    /// <returns>The (possibly normalized) coordinate lists; unchanged when no bounding box can be built</returns>
    private (List<float>, List<float>) NormalizeHand(List<float> handX, List<float> handY)
    {
        // Coordinates of exactly 0 are treated as missing and must not influence the bounding box
        List<float> xValues = handX.Where(value => value != 0).ToList();
        List<float> yValues = handY.Where(value => value != 0).ToList();
        if (xValues.Count == 0 || yValues.Count == 0)
        {
            return (handX, handY);
        }

        float minX = xValues.Min();
        float maxX = xValues.Max();
        float minY = yValues.Min();
        float maxY = yValues.Max();
        float width = maxX - minX;
        float height = maxY - minY;

        // Add a 10% margin on the longest side and extend the shortest side so the box becomes square
        float deltaX, deltaY;
        if (width > height)
        {
            deltaX = 0.1f * width;
            deltaY = deltaX + ((width - height) / 2f);
        }
        else
        {
            deltaY = 0.1f * height;
            deltaX = deltaY + ((height - width) / 2f);
        }

        float startX = minX - deltaX;
        float startY = minY - deltaY;
        float boxWidth = (maxX + deltaX) - startX;
        float boxHeight = (maxY + deltaY) - startY;
        if (boxWidth == 0f || boxHeight == 0f)
        {
            // A degenerate box would cause a division by zero
            return (handX, handY);
        }

        for (int i = 0; i < handX.Count; i++)
        {
            handX[i] = (handX[i] - startX) / boxWidth;
            handY[i] = (handY[i] - startY) / boxHeight;
        }
        return (handX, handY);
    }

    /// <summary>
    /// Normalizes the pose keypoints (in place) to a bounding box which is derived from the shoulder
    /// distance, or from the neck-nose distance when the shoulders are not usable
    /// </summary>
    /// <param name="poseX">X coordinates of the pose, including the neck at NeckIndex</param>
    /// <param name="poseY">Y coordinates of the pose, including the neck at NeckIndex</param>
    /// <returns>The (possibly normalized) coordinate lists; unchanged when no bounding box can be built</returns>
    private (List<float>, List<float>) NormalizePose(List<float> poseX, List<float> poseY)
    {
        var leftShoulder = new Vector2(poseX[LeftShoulderIndex], poseY[LeftShoulderIndex]);
        var rightShoulder = new Vector2(poseX[RightShoulderIndex], poseY[RightShoulderIndex]);
        var neck = new Vector2(poseX[NeckIndex], poseY[NeckIndex]);
        var nose = new Vector2(poseX[NoseIndex], poseY[NoseIndex]);

        // Components are compared exactly on purpose: Vector2's == operator is an approximate comparison
        bool shouldersUsable = leftShoulder.x != 0 && rightShoulder.x != 0 &&
            (leftShoulder.x != rightShoulder.x || leftShoulder.y != rightShoulder.y);
        bool headUsable = neck.x != 0 && nose.x != 0 &&
            (neck.x != nose.x || neck.y != nose.y);

        // Prevent from even starting the analysis if the necessary elements are not present
        if (!shouldersUsable && !headUsable)
        {
            return (poseX, poseY);
        }

        float headMetric = shouldersUsable
            ? Vector2.Distance(leftShoulder, rightShoulder)
            : Vector2.Distance(neck, nose);

        // Set the starting and ending point of the normalization bounding box;
        // the ending y is derived from the starting y before any clamping is applied
        float startX = poseX[NeckIndex] - 3 * headMetric;
        float startY = poseY[LeftEyeIndex] + headMetric;
        float endX = poseX[NeckIndex] + 3 * headMetric;
        float endY = startY - 6 * headMetric;

        // The box may not extend into negative coordinates
        startX = Mathf.Max(0f, startX);
        startY = Mathf.Max(0f, startY);
        endX = Mathf.Max(0f, endX);
        endY = Mathf.Max(0f, endY);

        float boxWidth = endX - startX;
        float boxHeight = startY - endY;
        if (boxWidth == 0f || boxHeight == 0f)
        {
            // Clamping can collapse the box; dividing by it would fill the frame with NaN/Infinity,
            // so the keypoints are left untouched just like NormalizeHand does
            return (poseX, poseY);
        }

        // Normalize the keypoints
        for (int i = 0; i < poseX.Count; i++)
        {
            poseX[i] = (poseX[i] - startX) / boxWidth;
            poseY[i] = (poseY[i] - endY) / boxHeight;
        }
        return (poseX, poseY);
    }

    /// <summary>
    /// Appends the neck, which is the midpoint between both shoulders, to the keypoints
    /// </summary>
    /// <param name="keypointsX">X coordinates of the pose landmarks</param>
    /// <param name="keypointsY">Y coordinates of the pose landmarks</param>
    /// <returns>The same lists with the neck appended as last element</returns>
    private (List<float>, List<float>) CalculateNeck(List<float> keypointsX, List<float> keypointsY)
    {
        float neckX = (keypointsX[LeftShoulderIndex] + keypointsX[RightShoulderIndex]) / 2;
        float neckY = (keypointsY[LeftShoulderIndex] + keypointsY[RightShoulderIndex]) / 2;
        keypointsX.Add(neckX);
        keypointsY.Add(neckY);
        return (keypointsX, keypointsY);
    }

    /// <summary>
    /// Copies the x and y coordinates of a landmark list into two lists of exactly expectedCount entries
    /// </summary>
    /// <param name="landmarks">The landmarks to read, null when nothing was detected</param>
    /// <param name="expectedCount">Number of entries to return; missing entries are filled with 0, extra landmarks are ignored</param>
    /// <returns>The x and y coordinates</returns>
    private static (List<float>, List<float>) ReadLandmarks(NormalizedLandmarkList landmarks, int expectedCount)
    {
        // One extra slot of capacity because the neck is appended to the pose afterwards
        var xs = new List<float>(expectedCount + 1);
        var ys = new List<float>(expectedCount + 1);
        if (landmarks != null)
        {
            foreach (NormalizedLandmark landmark in landmarks.Landmark)
            {
                if (xs.Count == expectedCount)
                {
                    break;
                }
                xs.Add(landmark.X);
                ys.Add(landmark.Y);
            }
        }
        while (xs.Count < expectedCount)
        {
            xs.Add(0f);
            ys.Add(0f);
        }
        return (xs, ys);
    }

    /// <summary>
    /// Appends the selected keypoints as [x, y] pairs, shifted by -0.5, to the frame
    /// </summary>
    /// <param name="frame">The frame to append to</param>
    /// <param name="xs">X coordinates to select from</param>
    /// <param name="ys">Y coordinates to select from</param>
    /// <param name="indices">The keypoints to select, in output order</param>
    private static void AppendCentered(List<List<float>> frame, List<float> xs, List<float> ys, int[] indices)
    {
        foreach (int index in indices)
        {
            frame.Add(new List<float> { xs[index] - 0.5f, ys[index] - 0.5f });
        }
    }

    /// <summary>
    /// Normalizes the landmarks of one frame and adds them to the buffer.
    /// The frame is skipped when there is no pose or when neither of the hands is present.
    /// </summary>
    /// <param name="poseLandmarks">Pose landmarks, null when not detected</param>
    /// <param name="leftHandLandmarks">Left hand landmarks, null when not detected</param>
    /// <param name="rightHandLandmarks">Right hand landmarks, null when not detected</param>
    public void AddLandmarks(NormalizedLandmarkList poseLandmarks, NormalizedLandmarkList leftHandLandmarks, NormalizedLandmarkList rightHandLandmarks)
    {
        if (poseLandmarks == null || (leftHandLandmarks == null && rightHandLandmarks == null))
        {
            return;
        }

        // Pose: add the neck, normalize and keep the keypoints the model expects
        var (poseX, poseY) = ReadLandmarks(poseLandmarks, PoseLandmarkCount);
        (poseX, poseY) = CalculateNeck(poseX, poseY);
        (poseX, poseY) = NormalizePose(poseX, poseY);

        // Hands: a missing hand is represented by zeros
        var (leftX, leftY) = ReadLandmarks(leftHandLandmarks, HandLandmarkCount);
        var (rightX, rightY) = ReadLandmarks(rightHandLandmarks, HandLandmarkCount);
        (leftX, leftY) = NormalizeHand(leftX, leftY);
        (rightX, rightY) = NormalizeHand(rightX, rightY);

        // The frame layout is: pose keypoints, left hand keypoints, right hand keypoints
        var frame = new List<List<float>>(PoseIndices.Length + 2 * HandIndices.Length);
        AppendCentered(frame, poseX, poseY, PoseIndices);
        AppendCentered(frame, leftX, leftY, HandIndices);
        AppendCentered(frame, rightX, rightY, HandIndices);

        keypointsBuffer.Add(frame);
        if (keypointsBuffer.Count > BufferSize)
        {
            // Drop the oldest frame to keep a sliding window
            keypointsBuffer.RemoveAt(0);
        }
    }

    /// <summary>
    /// Get the buffered keypoints
    /// </summary>
    /// <returns>The buffer (frames x keypoints x [x, y]), or null as long as the buffer is not filled</returns>
    public List<List<List<float>>> GetKeypoints()
    {
        if (keypointsBuffer.Count < BufferSize)
        {
            return null;
        }
        return keypointsBuffer;
    }
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 8978a7c17464d4fa8ab9f33be45a2bb6
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -19,4 +19,7 @@ MonoBehaviour:
- index: 1
modelWINDOWS: {fileID: 8538825877217656561, guid: fdbf401e965a6bf4a87637cd519f2715, type: 3}
modelMAC: {fileID: 8538825877217656561, guid: be31548ec7e7544fe9828b14657bb40b, type: 3}
currentModelIndex: 1
- index: 2
modelWINDOWS: {fileID: 8538825877217656561, guid: fa63c40c78ba548468cad97b15cdc6c9, type: 3}
modelMAC: {fileID: 8538825877217656561, guid: 17fb70e1c284e44da8083b36bb6afcb8, type: 3}
currentModelIndex: 2

View File

@ -10,12 +10,103 @@ using System.Linq;
using System.Threading.Tasks;
using UnityEngine;
using UnityEngine.UI;
using System.IO;
/// <summary>
/// One reference embedding entry, deserialized from the embeddings JSON file via JsonUtility.
/// The fields are public and keep their JSON names because JsonUtility maps fields by name.
/// </summary>
[System.Serializable]
public class EmbeddingData
{
/// <summary>
/// The reference embedding vector, which is compared to the model output to compute a distance
/// </summary>
public float[] embeddings;
/// <summary>
/// Name of the sign this embedding belongs to, used as key for the resulting probabilities
/// </summary>
public string label_name;
/// <summary>
/// NOTE(review): presumably the numeric class id which matches label_name - confirm against the embeddings file
/// </summary>
public int labels;
}
/// <summary>
/// Wrapper around a list of EmbeddingData, used as the target type for JsonUtility.FromJson:
/// the JSON read from the embeddings file is wrapped in an object with a "dataList" member before parsing
/// </summary>
[System.Serializable]
public class EmbeddingDataList
{
/// <summary>
/// All reference embeddings
/// </summary>
public List<EmbeddingData> dataList;
}
/// <summary>
/// Pairs a reference embedding with its distance to a predicted embedding
/// </summary>
public class DistanceEmbedding
{
/// <summary>
/// The distance between the predicted embedding and the reference embedding
/// </summary>
public float distance;
/// <summary>
/// The reference embedding the distance was computed for
/// </summary>
public EmbeddingData embeddingData;
/// <summary>
/// Creation of a DistanceEmbedding instance
/// </summary>
/// <param name="distance">The distance to the reference embedding</param>
/// <param name="embeddingData">The reference embedding</param>
public DistanceEmbedding(float distance, EmbeddingData embeddingData)
{
this.distance = distance;
this.embeddingData = embeddingData;
}
}
/// <summary>
/// Comparer which orders DistanceEmbedding instances by ascending distance
/// </summary>
public class DistanceComparer : IComparer<DistanceEmbedding>
{
    /// <summary>
    /// Compares the distances of two DistanceEmbedding instances
    /// </summary>
    /// <param name="x">The first instance</param>
    /// <param name="y">The second instance</param>
    /// <returns>Negative when x is closer than y, zero when equally close, positive otherwise</returns>
    public int Compare(DistanceEmbedding x, DistanceEmbedding y)
    {
        float first = x.distance;
        float second = y.distance;
        return first.CompareTo(second);
    }
}
/// <summary>
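/// MonoBehaviour which runs the Mediapipe pipeline on the webcam feed and predicts signs from the resulting landmarks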
/// </summary>
public class SignPredictor : MonoBehaviour
{
/// <summary>
/// Predictor class which runs the embedding MLEdgeModel and returns its flattened output
/// </summary>
public class NatMLSignPredictorEmbed : IMLPredictor<List<float>>
{
    /// <summary>
    /// The MLEdgeModel used for predictions
    /// </summary>
    private readonly MLEdgeModel edgeModel;

    /// <summary>
    /// The type used to create features which are input for the model
    /// </summary>
    private readonly MLFeatureType featureType;

    /// <summary>
    /// Creation of a NatMLSignPredictorEmbed instance
    /// </summary>
    /// <param name="edgeModel">The model to predict with; its first input determines the feature type</param>
    public NatMLSignPredictorEmbed(MLEdgeModel edgeModel)
    {
        this.edgeModel = edgeModel;
        featureType = edgeModel.inputs[0];
    }

    /// <summary>
    /// Runs the MLEdgeModel on the first input feature
    /// </summary>
    /// <param name="inputs">The input features, only the first one is used</param>
    /// <returns>The flattened first output of the model, or null when the model produced no output</returns>
    public List<float> Predict(params MLFeature[] inputs)
    {
        IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
        MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
        // The features are released in finally blocks, so they do not leak
        // when the prediction or the flattening of the output throws
        try
        {
            MLFeatureCollection<MLEdgeFeature> result = edgeModel.Predict(edgeFeature);
            try
            {
                if (0 < result.Count)
                {
                    return new MLArrayFeature<float>(result[0]).Flatten().ToArray().ToList();
                }
                return null;
            }
            finally
            {
                result.Dispose();
            }
        }
        finally
        {
            edgeFeature.Dispose();
        }
    }

    /// <summary>
    /// Disposing the MLEdgeModel
    /// </summary>
    public void Dispose()
    {
        edgeModel.Dispose();
    }
}
/// <summary>
/// Predictor class which is used to predict the sign using an MLEdgeModel
/// </summary>
@ -79,8 +170,11 @@ public class SignPredictor : MonoBehaviour
/// <summary>
/// Predictor which is used to create the asyncPredictor (should not be used if asyncPredictor exists)
/// </summary>
private NatMLSignPredictorEmbed predictor_embed;
private NatMLSignPredictor predictor;
/// <summary>
/// The asynchronous predictor which is used to predict the sign using an MLEdgemodel
/// </summary>
@ -178,6 +272,11 @@ public class SignPredictor : MonoBehaviour
/// </summary>
private KeypointManager keypointManager;
/// <summary>
/// A keypointmanager which does normalization stuff, keeps track of the landmarks (for embedding model)
/// </summary>
private KeypointManagerEmbedding keypointManagerEmbedding;
/// <summary>
/// Width of the webcam
/// </summary>
@ -198,6 +297,10 @@ public class SignPredictor : MonoBehaviour
/// </summary>
private static bool resourceManagerIsInitialized = false;
private EmbeddingDataList embeddingDataList;
private ModelIndex modelID;
/// <summary>
/// Google Mediapipe setup & run
/// </summary>
@ -258,9 +361,6 @@ public class SignPredictor : MonoBehaviour
graph.StartRun().AssertOk();
stopwatch.Start();
// Creating a KeypointManager
keypointManager = new KeypointManager(modelInfoFile);
// Check if a model is ready to load
yield return new WaitUntil(() => modelList.HasValidModel());
@ -268,12 +368,47 @@ public class SignPredictor : MonoBehaviour
Task<MLEdgeModel> t = Task.Run(() => MLEdgeModel.Create(modelList.GetCurrentModel()));
yield return new WaitUntil(() => t.IsCompleted);
model = t.Result;
predictor = new NatMLSignPredictor(model);
asyncPredictor = predictor.ToAsync();
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutine());
StartCoroutine(MediapipeCoroutine());
modelID = modelList.GetCurrentModelIndex();
if (modelID == ModelIndex.FINGERSPELLING)
{
predictor = new NatMLSignPredictor(model);
asyncPredictor = predictor.ToAsync();
// Creating a KeypointManager
keypointManager = new KeypointManager(modelInfoFile);
StartCoroutine(SignRecognitionCoroutine());
StartCoroutine(MediapipeCoroutine());
}
else
{
predictor_embed = new NatMLSignPredictorEmbed(model);
asyncPredictor = predictor_embed.ToAsync();
// Creating a KeypointManager
keypointManagerEmbedding = new KeypointManagerEmbedding();
// read the embedding data
string filePath = Path.Combine(Application.dataPath, "Common/Models/BasicSigns", "embeddings.json");
if (File.Exists(filePath))
{
string jsonData = File.ReadAllText(filePath);
UnityEngine.Debug.Log(jsonData);
embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>("{\"dataList\":" + jsonData + "}");
}
else
{
UnityEngine.Debug.LogError("File not found: " + filePath);
}
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutineEmbed());
StartCoroutine(MediapipeCoroutineEmbed());
}
}
/// <summary>
@ -302,6 +437,73 @@ public class SignPredictor : MonoBehaviour
}
}
/// <summary>
/// Coroutine which feeds the webcam frames to the Mediapipe graph and hands the
/// resulting pose and hand landmarks to the embedding keypoint manager
/// </summary>
/// <returns></returns>
private IEnumerator MediapipeCoroutineEmbed()
{
    for (; ; )
    {
        // Copy the current webcam frame into the input texture and wrap its raw bytes in an ImageFrame
        var framePixels = webcamTexture.GetPixels32(pixelData);
        inputTexture.SetPixels32(framePixels);
        var rawBytes = inputTexture.GetRawTextureData<byte>();
        var frame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, rawBytes);

        // The stopwatch ticks are scaled down with TimeSpan's ticks-per-microsecond ratio
        var ticksPerMicrosecond = System.TimeSpan.TicksPerMillisecond / 1000;
        var timestamp = stopwatch.ElapsedTicks / ticksPerMicrosecond;
        var packet = new ImageFramePacket(frame, new Timestamp(timestamp));
        graph.AddPacketToInputStream("input_video", packet).AssertOk();

        yield return new WaitForEndOfFrame();

        // Every stream is polled exactly once; a landmark list stays null when its stream had nothing to offer
        NormalizedLandmarkList pose = null;
        NormalizedLandmarkList leftHand = null;
        NormalizedLandmarkList rightHand = null;
        yield return new WaitUntil(() => { posestream.TryGetNext(out pose); return true; });
        yield return new WaitUntil(() => { leftstream.TryGetNext(out leftHand); return true; });
        yield return new WaitUntil(() => { rightstream.TryGetNext(out rightHand); return true; });

        keypointManagerEmbedding.AddLandmarks(pose, leftHand, rightHand);
    }
}
/// <summary>
/// Computes the Minkowski distance of order p between two vectors
/// </summary>
/// <param name="x">The first vector; its length determines how many components are compared</param>
/// <param name="y">The second vector, which is indexed over the same range as x</param>
/// <param name="p">The order of the distance</param>
/// <returns>The p-th root of the sum of the absolute component differences raised to the power p</returns>
private float MinkowskiDistance(List<float> x, float[] y, int p)
{
    float accumulated = 0;
    int component = 0;
    while (component < x.Count)
    {
        float gap = Mathf.Abs(x[component] - y[component]);
        accumulated += Mathf.Pow(gap, p);
        component++;
    }
    float exponent = 1.0f / p;
    return Mathf.Pow(accumulated, exponent);
}
/// <summary>
/// Computes the distance between an embedding and every reference embedding
/// </summary>
/// <param name="embedding">The embedding to compare with the reference embeddings</param>
/// <param name="p">The order of the Minkowski distance</param>
/// <returns>
/// The distances sorted from closest to furthest; empty when no reference embeddings are loaded
/// </returns>
private List<DistanceEmbedding> GetDistances(List<float> embedding, int p = 2)
{
    // The reference embeddings stay unset when the embeddings file could not be found or parsed
    if (embeddingDataList == null || embeddingDataList.dataList == null)
    {
        return new List<DistanceEmbedding>();
    }

    List<DistanceEmbedding> distances = new List<DistanceEmbedding>(embeddingDataList.dataList.Count);
    foreach (EmbeddingData data in embeddingDataList.dataList)
    {
        float distance = MinkowskiDistance(embedding, data.embeddings, p);
        distances.Add(new DistanceEmbedding(distance, data));
    }

    // Sorting once at the end avoids shifting the list on every insertion
    distances.Sort(new DistanceComparer());
    return distances;
}
/// <summary>
/// Coroutine which calls the sign predictor model
/// </summary>
@ -345,7 +547,6 @@ public class SignPredictor : MonoBehaviour
{
learnableProbabilities.Add(signs[j].ToUpper(), result[j]);
}
//Debug.Log($"prob = [{learnableProbabilities.Aggregate(" ", (t, kv) => $"{t}{kv.Key}:{kv.Value} ")}]");
foreach (Listener listener in listeners)
{
yield return listener.ProcessIncomingCall();
@ -363,6 +564,88 @@ public class SignPredictor : MonoBehaviour
}
/// <summary>
/// Coroutine which calls the sign predictor embedding model: the buffered keypoints are turned into an
/// embedding, the embedding is compared with the reference embeddings and the smallest distance per
/// sign is converted into a probability before the listeners are notified
/// </summary>
/// <returns></returns>
private IEnumerator SignRecognitionCoroutineEmbed()
{
    // Parameters of the logistic function which maps a distance onto a probability:
    // a distance equal to the midpoint gives 0.5, smaller distances give higher probabilities
    const float distanceMidpoint = 1.6f;
    const float steepness = 2f;

    while (true)
    {
        List<List<List<float>>> inputData = keypointManagerEmbedding.GetKeypoints();
        if (inputData != null && asyncPredictor.readyForPrediction)
        {
            // Getting the size of the input data
            int framecount = inputData.Count;
            int keypointsPerFrame = inputData[0].Count;

            // Creating ArrayFeature; the buffer is flattened before yielding, so later
            // changes to the buffer do not affect this prediction
            int[] shape = { 1, framecount, keypointsPerFrame, 2 };
            float[] input = new float[framecount * keypointsPerFrame * 2];
            int i = 0;
            foreach (List<List<float>> frame in inputData)
            {
                foreach (List<float> keypoint in frame)
                {
                    foreach (float coordinate in keypoint)
                    {
                        input[i++] = coordinate;
                    }
                }
            }
            MLArrayFeature<float> feature = new MLArrayFeature<float>(input, shape);

            // Predicting
            Task<List<float>> task = Task.Run(async () => await asyncPredictor.Predict(feature));
            yield return new WaitUntil(() => task.IsCompleted);
            List<float> result = task.Result;

            // The predictor returns null when the model produced no output
            if (result != null && 0 < result.Count)
            {
                List<DistanceEmbedding> distances = GetDistances(result, 2);

                // Keep the smallest distance for every sign
                Dictionary<string, float> closestPerLabel = new Dictionary<string, float>();
                foreach (DistanceEmbedding candidate in distances)
                {
                    string label = candidate.embeddingData.label_name;
                    if (!closestPerLabel.TryGetValue(label, out float best) || candidate.distance < best)
                    {
                        closestPerLabel[label] = candidate.distance;
                    }
                }

                // Convert the distances to probabilities
                Dictionary<string, float> probabilities = new Dictionary<string, float>(closestPerLabel.Count);
                foreach (KeyValuePair<string, float> entry in closestPerLabel)
                {
                    float probability = 1f / (1f + Mathf.Exp((entry.Value - distanceMidpoint) * steepness));
                    probabilities.Add(entry.Key, probability);
                }

                // The field is assigned once, so it never holds raw distances
                learnableProbabilities = probabilities;

                foreach (Listener listener in listeners)
                {
                    yield return listener.ProcessIncomingCall();
                }
            }
        }
        yield return null;
    }
}
/// <summary>
/// Proper destruction of the Mediapipe graph
/// </summary>

View File

@ -5,6 +5,7 @@
public enum CourseIndex
{
FINGERSPELLING,
BASICSIGNS,
CLOTHING,
ANIMALS,
FOOD,

View File

@ -160,7 +160,8 @@ PlayerSettings:
- {fileID: 0}
- {fileID: 0}
- {fileID: 0}
- {fileID: 11400000, guid: 46a77681e9be442a9b3cceaa98c5d128, type: 2}
- {fileID: 0}
- {fileID: 0}
metroInputSource: 0
wsaTransparentSwapchain: 0
m_HolographicPauseOnTrackingLoss: 1