Files
unity-application/Assets/MediaPipeUnity/Scripts/SignPredictor.cs
2023-05-16 11:09:41 +00:00

647 lines
23 KiB
C#

using Mediapipe;
using Mediapipe.Unity;
using NatML;
using NatML.Features;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using UnityEngine;
using UnityEngine.UI;
/// <summary>
/// One reference embedding record. Instances are produced by <c>JsonUtility.FromJson</c>
/// (via <see cref="EmbeddingDataList"/>), which binds JSON keys to these public fields by name.
/// </summary>
[System.Serializable]
public class EmbeddingData
{
// Reference embedding vector; compared element-wise with the model output in SignPredictor.MinkowskiDistance.
public float[] embeddings;
// Label text; used as the dictionary key when SignPredictor publishes per-label distances.
public string label_name;
// Numeric label; not read anywhere in this file (presumably a class index — confirm against the data producer).
public int labels;
}
/// <summary>
/// Wrapper object holding a list of <see cref="EmbeddingData"/>.
/// SignPredictor wraps the raw embeddings JSON in <c>{"dataList": ...}</c> before
/// handing it to <c>JsonUtility.FromJson</c>, so the field name must stay <c>dataList</c>.
/// </summary>
[System.Serializable]
public class EmbeddingDataList
{
// All reference embeddings that a model output is compared against.
public List<EmbeddingData> dataList;
}
/// <summary>
/// Pairs a reference embedding with its computed distance to a model output.
/// </summary>
public class DistanceEmbedding
{
// Distance between the model output and embeddingData.embeddings (lower means closer).
public float distance;
// The reference embedding the distance was measured against.
public EmbeddingData embeddingData;
/// <summary>
/// Creates a distance/embedding pair.
/// </summary>
/// <param name="distance">Computed distance to <paramref name="embeddingData"/>.</param>
/// <param name="embeddingData">The reference embedding the distance belongs to.</param>
public DistanceEmbedding(float distance, EmbeddingData embeddingData)
{
this.distance = distance;
this.embeddingData = embeddingData;
}
}
/// <summary>
/// Orders <see cref="DistanceEmbedding"/> instances by ascending <c>distance</c>.
/// </summary>
public class DistanceComparer : IComparer<DistanceEmbedding>
{
    /// <summary>
    /// Compares two entries by their <c>distance</c> value.
    /// </summary>
    /// <param name="x">Left-hand entry (must not be null).</param>
    /// <param name="y">Right-hand entry (must not be null).</param>
    /// <returns>
    /// A negative value when <paramref name="x"/> is closer, zero when both distances
    /// compare equal, and a positive value when <paramref name="y"/> is closer.
    /// </returns>
    public int Compare(DistanceEmbedding x, DistanceEmbedding y)
    {
        float left = x.distance;
        float right = y.distance;
        return left.CompareTo(right);
    }
}
/// <summary>
/// Sign predictor class which gives input to the games/courses by extracting information from the webcam
/// </summary>
public class SignPredictor : MonoBehaviour
{
/// <summary>
/// Predictor that runs an MLEdgeModel and returns its first output, flattened, as the
/// raw embedding vector (no post-processing is applied).
/// </summary>
public class NatMLSignPredictorEmbed : IMLPredictor<List<float>>
{
    /// <summary>
    /// The MLEdgeModel used for predictions
    /// </summary>
    private readonly MLEdgeModel edgeModel;
    /// <summary>
    /// The feature type of the model's first input, used to create edge features
    /// </summary>
    private readonly MLFeatureType featureType;

    /// <summary>
    /// Creates a predictor around an already created model.
    /// </summary>
    /// <param name="edgeModel">Model to run; it is disposed together with this predictor.</param>
    public NatMLSignPredictorEmbed(MLEdgeModel edgeModel)
    {
        this.edgeModel = edgeModel;
        featureType = edgeModel.inputs[0];
    }

    /// <summary>
    /// Runs the model on the first input feature.
    /// </summary>
    /// <param name="inputs">Input features; only <c>inputs[0]</c> is used and it must implement IMLEdgeFeature.</param>
    /// <returns>The flattened first model output, or <c>null</c> when the model produced no outputs.</returns>
    public List<float> Predict(params MLFeature[] inputs)
    {
        IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
        MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
        try
        {
            MLFeatureCollection<MLEdgeFeature> result = edgeModel.Predict(edgeFeature);
            try
            {
                if (result.Count <= 0)
                {
                    return null;
                }
                return new MLArrayFeature<float>(result[0]).Flatten().ToArray().ToList();
            }
            finally
            {
                // Release the outputs even when converting them throws.
                result.Dispose();
            }
        }
        finally
        {
            // Release the input feature even when the model call throws.
            edgeFeature.Dispose();
        }
    }

    /// <summary>
    /// Disposes the wrapped MLEdgeModel
    /// </summary>
    public void Dispose()
    {
        edgeModel.Dispose();
    }
}
/// <summary>
/// Predictor that runs an MLEdgeModel and converts its first output into a probability
/// distribution with a softmax.
/// </summary>
public class NatMLSignPredictor : IMLPredictor<List<float>>
{
    /// <summary>
    /// The MLEdgeModel used for predictions
    /// </summary>
    private readonly MLEdgeModel edgeModel;
    /// <summary>
    /// The feature type of the model's first input, used to create edge features
    /// </summary>
    private readonly MLFeatureType featureType;

    /// <summary>
    /// Creates a predictor around an already created model.
    /// </summary>
    /// <param name="edgeModel">Model to run; it is disposed together with this predictor.</param>
    public NatMLSignPredictor(MLEdgeModel edgeModel)
    {
        this.edgeModel = edgeModel;
        featureType = edgeModel.inputs[0];
    }

    /// <summary>
    /// Runs the model on the first input feature and applies a softmax to the output.
    /// </summary>
    /// <param name="inputs">Input features; only <c>inputs[0]</c> is used and it must implement IMLEdgeFeature.</param>
    /// <returns>
    /// Softmax of the flattened first model output (values sum to 1), an empty list when that
    /// output has no elements, or <c>null</c> when the model produced no outputs.
    /// </returns>
    public List<float> Predict(params MLFeature[] inputs)
    {
        IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
        MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
        try
        {
            MLFeatureCollection<MLEdgeFeature> result = edgeModel.Predict(edgeFeature);
            try
            {
                if (result.Count <= 0)
                {
                    return null;
                }
                float[] logits = new MLArrayFeature<float>(result[0]).Flatten().ToArray();
                return Softmax(logits);
            }
            finally
            {
                // Release the outputs even when converting them throws.
                result.Dispose();
            }
        }
        finally
        {
            // Release the input feature even when the model call throws.
            edgeFeature.Dispose();
        }
    }

    /// <summary>
    /// Numerically stable softmax: the maximum is subtracted before exponentiating so that
    /// large values cannot overflow to Infinity (which would turn the result into NaN).
    /// </summary>
    private static List<float> Softmax(float[] logits)
    {
        List<float> probabilities = new List<float>(logits.Length);
        if (logits.Length == 0)
        {
            return probabilities;
        }
        float maxLogit = logits.Max();
        float sum = 0f;
        foreach (float logit in logits)
        {
            float exponent = Mathf.Exp(logit - maxLogit);
            probabilities.Add(exponent);
            sum += exponent;
        }
        for (int i = 0; i < probabilities.Count; i++)
        {
            probabilities[i] /= sum;
        }
        return probabilities;
    }

    /// <summary>
    /// Disposes the wrapped MLEdgeModel
    /// </summary>
    public void Dispose()
    {
        edgeModel.Dispose();
    }
}
/// <summary>
/// Listeners that get notified on new predictions: after every successful prediction the
/// recognition coroutine yields each listener's <c>ProcessIncomingCall()</c> in list order.
/// </summary>
public List<Listener> listeners = new List<Listener>();
/// <summary>
/// Synchronous predictor wrapped by asyncPredictor (should not be used directly once asyncPredictor exists)
/// </summary>
private NatMLSignPredictorEmbed predictor_embed;
/// <summary>
/// The asynchronous predictor which is used to run the embedding model; created in Start
/// </summary>
private MLAsyncPredictor<List<float>> asyncPredictor;
/// <summary>
/// Reference to the model used in the SignPredictor; created in Start from modelList.GetCurrentModel()
/// </summary>
private MLEdgeModel model;
/// <summary>
/// Model list; supplies the current model and its embeddings, and is switched through SetModel
/// </summary>
public ModelList modelList;
/// <summary>
/// Chosen model data based on the operating system (not used in the visible code of this class)
/// </summary>
private MLModelData modelData;
/// <summary>
/// Reference to the model info file (not read in the visible code of this class)
/// </summary>
public TextAsset modelInfoFile;
/// <summary>
/// Reference to the model embedding file (not read in the visible code of this class)
/// </summary>
public TextAsset modelInfoFileEmbedding;
/// <summary>
/// Config file whose text is used to set up the MediaPipe graph
/// </summary>
[SerializeField]
private TextAsset configAsset;
/// <summary>
/// Index into WebCamTexture.devices of the camera selected by SwapCam (Start always opens device 0)
/// </summary>
private int camdex = 0;
/// <summary>
/// The screen object on which the video is displayed
/// </summary>
[SerializeField]
private RawImage screen;
/// <summary>
/// MediaPipe graph
/// </summary>
private CalculatorGraph graph;
/// <summary>
/// Resource manager for graph resources; only created by the first instance that runs Start
/// </summary>
private ResourceManager resourceManager;
/// <summary>
/// Webcam texture
/// </summary>
private WebCamTexture webcamTexture = null;
/// <summary>
/// CPU-side texture the webcam pixels are copied into before being handed to MediaPipe
/// </summary>
private Texture2D inputTexture;
/// <summary>
/// Reusable buffer for the webcam pixels (width * height entries, allocated in Start)
/// </summary>
private Color32[] pixelData;
/// <summary>
/// Stopwatch to give a timestamp to video frames
/// </summary>
private Stopwatch stopwatch;
/// <summary>
/// The mediapipe stream which contains the pose landmarks
/// </summary>
private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> posestream;
/// <summary>
/// The mediapipe stream which contains the left hand landmarks
/// </summary>
private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> leftstream;
/// <summary>
/// The mediapipe stream which contains the right hand landmarks
/// </summary>
private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> rightstream;
/// <summary>
/// Presence stream (never assigned in the visible code of this class)
/// </summary>
public OutputStream<DetectionVectorPacket, List<Detection>> presenceStream;
/// <summary>
/// A keypointmanager which does normalization and keeps track of the landmarks
/// (never assigned in the visible code of this class)
/// </summary>
private KeypointManager keypointManager;
/// <summary>
/// A keypointmanager for the embedding model: receives the landmarks of every frame and
/// supplies the keypoints that are fed to the model
/// </summary>
private KeypointManagerEmbedding keypointManagerEmbedding;
/// <summary>
/// Width of the webcam texture as captured in Start (not refreshed by SwapCam)
/// </summary>
private int width;
/// <summary>
/// Height of the webcam texture as captured in Start (not refreshed by SwapCam)
/// </summary>
private int height;
/// <summary>
/// Result of the latest prediction. Despite the name, the values are not probabilities:
/// each label name maps to the smallest distance (Minkowski, p = 2) between the model
/// output and the reference embeddings of that label, so lower means a better match.
/// </summary>
public Dictionary<string, float> learnableProbabilities;
/// <summary>
/// Bool indicating whether or not the resource manager has already been initialized;
/// static, so the assets are only prepared by the first instance that runs Start
/// </summary>
private static bool resourceManagerIsInitialized = false;
/// <summary>
/// Reference embeddings parsed in Start from modelList.GetEmbeddings()
/// </summary>
private EmbeddingDataList embeddingDataList;
/// <summary>
/// Unity start coroutine: opens the first webcam, prepares the MediaPipe assets and graph,
/// loads the NatML model and the reference embeddings, and finally launches the MediaPipe
/// and sign recognition coroutines.
/// </summary>
/// <returns>IEnumerator driven by Unity's coroutine scheduler</returns>
/// <exception cref="System.Exception">Thrown when no webcam device is available.</exception>
private IEnumerator Start()
{
// Webcam setup
if (WebCamTexture.devices.Length == 0)
{
throw new System.Exception("Web Camera devices are not found");
}
// Start the webcam (always the first device; SwapCam can switch afterwards)
WebCamDevice webCamDevice = WebCamTexture.devices[0];
webcamTexture = new WebCamTexture(webCamDevice.name);
webcamTexture.Play();
// Wait until the texture reports a width above 16 before reading its size
// (presumably a small placeholder size is reported before the first frame — TODO confirm)
yield return new WaitUntil(() => webcamTexture.width > 16);
// Set webcam aspect ratio: keep the screen height and derive the width from it
width = webcamTexture.width;
height = webcamTexture.height;
float webcamAspect = (float)webcamTexture.width / (float)webcamTexture.height;
screen.rectTransform.sizeDelta = new Vector2(screen.rectTransform.sizeDelta.y * webcamAspect, (screen.rectTransform.sizeDelta.y));
screen.texture = webcamTexture;
// TODO: frames are copied to the CPU through this Texture2D / Color32 buffer pair; the original
// note calls this approach "kinda meh" but does not name the intended replacement.
inputTexture = new Texture2D(width, height, TextureFormat.RGBA32, false);
pixelData = new Color32[width * height];
// The assets only have to be prepared once (the flag is static)
if (!resourceManagerIsInitialized)
{
resourceManager = new StreamingAssetsResourceManager();
yield return resourceManager.PrepareAssetAsync("pose_detection.bytes");
yield return resourceManager.PrepareAssetAsync("pose_landmark_full.bytes");
yield return resourceManager.PrepareAssetAsync("face_landmark.bytes");
yield return resourceManager.PrepareAssetAsync("hand_landmark_full.bytes");
yield return resourceManager.PrepareAssetAsync("face_detection_short_range.bytes");
yield return resourceManager.PrepareAssetAsync("hand_recrop.bytes");
yield return resourceManager.PrepareAssetAsync("handedness.txt");
resourceManagerIsInitialized = true;
}
stopwatch = new Stopwatch();
// Setting up the graph and one polled output stream per landmark list
graph = new CalculatorGraph(configAsset.text);
posestream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "pose_landmarks", "pose_landmarks_presence");
leftstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "left_hand_landmarks", "left_hand_landmarks_presence");
rightstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "right_hand_landmarks", "right_hand_landmarks_presence");
posestream.StartPolling().AssertOk();
leftstream.StartPolling().AssertOk();
rightstream.StartPolling().AssertOk();
graph.StartRun().AssertOk();
// The stopwatch provides the timestamps of the frames sent to the graph
stopwatch.Start();
// Check if a model is ready to load
yield return new WaitUntil(() => modelList.HasValidModel());
// Stays null (passed as-is to MLEdgeModel.Create) unless the GPU is requested
NatML.MLEdgeModel.Configuration myConfig = null;
// Create Configuration
if (PersistentDataController.GetInstance().IsUsingGPU())
{
// Create a new instance of the Configuration class
myConfig = new NatML.MLEdgeModel.Configuration();
// Set the computeTarget property to GPU
myConfig.computeTarget = NatML.MLEdgeModel.ComputeTarget.GPU;
}
// Create the model on a worker task and poll for completion so the coroutine does not block
Task<MLEdgeModel> t = Task.Run(() => MLEdgeModel.Create(modelList.GetCurrentModel(), myConfig));
yield return new WaitUntil(() => t.IsCompleted);
// NOTE(review): Result rethrows (as AggregateException) when model creation failed
model = t.Result;
predictor_embed = new NatMLSignPredictorEmbed(model);
asyncPredictor = predictor_embed.ToAsync();
// Creating a KeypointManager
keypointManagerEmbedding = new KeypointManagerEmbedding();
// read the embedding data; the embeddings JSON is wrapped in an object because the
// parsed type (EmbeddingDataList) expects a "dataList" field
embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>($"{{\"dataList\":{modelList.GetEmbeddings()}}}");
// Start the Coroutine
StartCoroutine(SignRecognitionCoroutineEmbed());
StartCoroutine(MediapipeCoroutineEmbed());
}
/// <summary>
/// Coroutine which executes the mediapipe pipeline: every iteration copies the current
/// webcam frame into the graph's "input_video" stream, waits for the end of the frame and
/// then forwards whatever the pose and hand landmark streams returned to the keypoint manager.
/// </summary>
/// <returns>IEnumerator that never completes on its own (it runs until the coroutine is stopped)</returns>
private IEnumerator MediapipeCoroutineEmbed()
{
    while (true)
    {
        // NOTE(review): pixelData and inputTexture are sized once in Start; confirm this still
        // holds after SwapCam selects a camera with a different resolution.
        inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
        var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData<byte>());

        // Timestamp in microseconds. Elapsed.Ticks is always expressed in 100 ns units, whereas
        // Stopwatch.ElapsedTicks is expressed in Stopwatch.Frequency units, which only happen
        // to be 100 ns on some platforms.
        var currentTimestamp = stopwatch.Elapsed.Ticks / (System.TimeSpan.TicksPerMillisecond / 1000);
        graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();
        yield return new WaitForEndOfFrame();

        NormalizedLandmarkList _poseLandmarks = null;
        NormalizedLandmarkList _leftHandLandmarks = null;
        NormalizedLandmarkList _rightHandLandmarks = null;
        // The predicates always return true, so each WaitUntil only polls its stream once;
        // a landmark list keeps whatever TryGetNext assigned (possibly null — TODO confirm).
        yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks); return true; });
        yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks); return true; });
        yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks); return true; });
        keypointManagerEmbedding.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
    }
}
/// <summary>
/// Computes the Minkowski distance of order <paramref name="p"/> between two points:
/// the p-th root of the sum of the absolute coordinate differences raised to the power p.
/// p = 1 gives the Manhattan distance, p = 2 the Euclidean distance.
/// </summary>
/// <param name="x">Coordinates of the first point; its element count determines how many coordinates are compared.</param>
/// <param name="y">Coordinates of the second point; must contain at least as many elements as <paramref name="x"/>.</param>
/// <param name="p">Order (power parameter) of the distance.</param>
/// <returns>The Minkowski distance between <paramref name="x"/> and <paramref name="y"/>.</returns>
private float MinkowskiDistance(List<float> x, float[] y, int p)
{
    int coordinateCount = x.Count;
    float accumulated = 0;
    int index = 0;
    while (index < coordinateCount)
    {
        float gap = Mathf.Abs(x[index] - y[index]);
        accumulated += Mathf.Pow(gap, p);
        index++;
    }
    float rootExponent = 1.0f / p;
    return Mathf.Pow(accumulated, rootExponent);
}
/// <summary>
/// Calculates the Minkowski distance between the given embedding and every reference
/// embedding in embeddingDataList, and returns the results sorted by ascending distance.
/// </summary>
/// <param name="embedding">The embedding (model output) to compare against all reference embeddings.</param>
/// <param name="p">Power parameter of the Minkowski distance. Defaults to 2 (Euclidean distance).</param>
/// <returns>
/// One DistanceEmbedding per reference embedding, sorted in ascending order of distance.
/// The relative order of entries with equal distance is unspecified.
/// </returns>
private List<DistanceEmbedding> GetDistances(List<float> embedding, int p = 2)
{
    List<EmbeddingData> references = embeddingDataList.dataList;
    // Pre-size the list: exactly one entry is produced per reference embedding.
    List<DistanceEmbedding> distances = new List<DistanceEmbedding>(references.Count);
    foreach (EmbeddingData data in references)
    {
        float distance = MinkowskiDistance(embedding, data.embeddings, p);
        distances.Add(new DistanceEmbedding(distance, data));
    }
    // A single sort replaces the per-item binary search + Insert, which shifted the tail
    // of the list on every insertion (quadratic number of element moves).
    distances.Sort(new DistanceComparer());
    return distances;
}
/// <summary>
/// Coroutine which calls the sign predictor embedding model. Whenever keypoints are available
/// and the predictor is ready, it runs the model, computes the distance of the resulting
/// embedding to every reference embedding, publishes the smallest distance per label in
/// learnableProbabilities and then notifies all listeners.
/// </summary>
/// <returns>IEnumerator that never completes on its own (it runs until the coroutine is stopped)</returns>
private IEnumerator SignRecognitionCoroutineEmbed()
{
    while (true)
    {
        List<List<List<float>>> inputData = keypointManagerEmbedding.GetKeypoints();
        // An empty window is skipped as well: inputData[0] below would throw on it.
        if (inputData != null && inputData.Count > 0 && asyncPredictor.readyForPrediction)
        {
            // Getting the size of the input data
            int framecount = inputData.Count;
            int keypointsPerFrame = inputData[0].Count;

            // Creating ArrayFeature: flatten frames -> keypoints -> values into one buffer
            int[] shape = { 1, framecount, keypointsPerFrame, 2 };
            float[] input = new float[framecount * keypointsPerFrame * 2];
            int i = 0;
            inputData.ForEach((e) => e.ForEach((f) => f.ForEach((k) => input[i++] = k)));
            MLArrayFeature<float> feature = new MLArrayFeature<float>(input, shape);

            // Predicting on a worker task; the coroutine polls for completion instead of blocking
            Task<List<float>> task = Task.Run(async () => await asyncPredictor.Predict(feature));
            yield return new WaitUntil(() => task.IsCompleted);
            List<float> result = task.Result;

            // The predictor returns null when the model produced no outputs.
            if (result != null && result.Count > 0)
            {
                List<DistanceEmbedding> distances = GetDistances(result, 2);

                // Keep the smallest distance per label. The dictionary is only published once
                // it is complete.
                Dictionary<string, float> closestPerLabel = new Dictionary<string, float>();
                foreach (DistanceEmbedding candidate in distances)
                {
                    string label = candidate.embeddingData.label_name;
                    float best;
                    if (!closestPerLabel.TryGetValue(label, out best))
                    {
                        closestPerLabel.Add(label, candidate.distance);
                    }
                    else if (best > candidate.distance)
                    {
                        closestPerLabel[label] = candidate.distance;
                    }
                }
                learnableProbabilities = closestPerLabel;

                // NOTE(review): listeners is enumerated across yields; adding or removing a
                // listener while this loop is suspended invalidates the enumerator.
                foreach (Listener listener in listeners)
                {
                    yield return listener.ProcessIncomingCall();
                }
            }
        }
        yield return null;
    }
}
/// <summary>
/// Proper teardown of everything created in Start: stops the webcam, shuts down and disposes
/// the MediaPipe graph, disposes the predictor and destroys the script-created input texture.
/// The predictor and the texture are released even when shutting down the graph throws.
/// </summary>
private void OnDestroy()
{
    if (webcamTexture != null)
    {
        webcamTexture.Stop();
    }
    try
    {
        if (graph != null)
        {
            try
            {
                graph.CloseInputStream("input_video").AssertOk();
                graph.WaitUntilDone().AssertOk();
            }
            finally
            {
                graph.Dispose();
                graph = null;
            }
        }
    }
    finally
    {
        if (asyncPredictor != null)
        {
            asyncPredictor.Dispose();
            asyncPredictor = null;
        }
        // Textures created from script are not released automatically; destroy it only after
        // the graph is gone so nothing can still be reading its pixel data.
        if (inputTexture != null)
        {
            Destroy(inputTexture);
            inputTexture = null;
        }
    }
}
/// <summary>
/// So long as there are cameras to use, swaps the camera in use for the next one in the
/// device list (wrapping around at the end). Does nothing when no camera is available.
/// </summary>
public void SwapCam()
{
    // Query the device list once; it is needed for the count and for the lookup.
    WebCamDevice[] devices = WebCamTexture.devices;
    if (devices.Length == 0)
    {
        return;
    }

    // Stop the old camera.
    // If there was no camera playing before (including the case where none was ever created),
    // the texture does not have to be reset, as it wasn't assigned in the first place.
    if (webcamTexture != null && webcamTexture.isPlaying)
    {
        screen.texture = null;
        webcamTexture.Stop();
        webcamTexture = null;
    }

    // Find the new camera
    camdex = (camdex + 1) % devices.Length;

    // Start the new camera
    // NOTE(review): width, height, pixelData and inputTexture keep the values captured in
    // Start; confirm that every selectable camera delivers the same resolution.
    WebCamDevice device = devices[camdex];
    webcamTexture = new WebCamTexture(device.name);
    screen.texture = webcamTexture;
    webcamTexture.Play();
}
/// <summary>
/// Selects the model that belongs to the chosen Game/Course by forwarding the index to the model list.
/// </summary>
/// <param name="index">Index of the model that should become the current model.</param>
public void SetModel(ModelIndex index) => modelList.SetCurrentModel(index);
/// <summary>
/// Makes the given RawImage the display target. When a webcam texture exists, the new screen
/// keeps its height, gets its width from the webcam aspect ratio and shows the webcam texture.
/// </summary>
/// <param name="screen">The RawImage that should display the webcam from now on.</param>
public void SwapScreen(RawImage screen)
{
    this.screen = screen;
    if (webcamTexture == null)
    {
        // No camera yet: only remember the new target.
        return;
    }

    float aspectRatio = (float)webcamTexture.width / (float)webcamTexture.height;
    RectTransform target = this.screen.rectTransform;
    float widthForAspect = target.sizeDelta.y * aspectRatio;
    target.sizeDelta = new Vector2(widthForAspect, (target.sizeDelta.y));
    this.screen.texture = webcamTexture;
}
}