using Mediapipe;
using Mediapipe.Unity;
using NatML;
using NatML.Features;
using NatML.Internal;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using UnityEngine;
using UnityEngine.UI;

/// <summary>
/// One reference embedding entry. Instances are created by JsonUtility when the embedding
/// model info file is parsed, so the field names have to stay equal to the JSON keys.
/// </summary>
[System.Serializable]
public class EmbeddingData
{
    /// <summary>
    /// The reference embedding vector that predicted embeddings are compared against
    /// </summary>
    public float[] embeddings;

    /// <summary>
    /// Name of the label this embedding belongs to (used as key for the predicted probabilities)
    /// </summary>
    public string label_name;

    /// <summary>
    /// Not read anywhere in this file; presumably a numeric id of the label (TODO confirm)
    /// </summary>
    public int labels;
}

/// <summary>
/// Wrapper object, needed because JsonUtility cannot deserialize a top-level JSON array
/// </summary>
[System.Serializable]
public class EmbeddingDataList
{
    /// <summary>
    /// All reference embeddings
    /// </summary>
    public List<EmbeddingData> dataList;
}

/// <summary>
/// Couples a reference embedding with its distance to a predicted embedding
/// </summary>
public class DistanceEmbedding
{
    /// <summary>
    /// Distance between the predicted embedding and <see cref="embeddingData"/>
    /// </summary>
    public float distance;

    /// <summary>
    /// The reference embedding the distance was computed against
    /// </summary>
    public EmbeddingData embeddingData;

    /// <summary>
    /// Creation of a DistanceEmbedding instance
    /// </summary>
    /// <param name="distance">Distance to the reference embedding</param>
    /// <param name="embeddingData">The reference embedding</param>
    public DistanceEmbedding(float distance, EmbeddingData embeddingData)
    {
        this.distance = distance;
        this.embeddingData = embeddingData;
    }
}

/// <summary>
/// Orders <see cref="DistanceEmbedding"/> instances by ascending distance
/// </summary>
public class DistanceComparer : IComparer<DistanceEmbedding>
{
    /// <summary>
    /// Compares two entries on their distance only
    /// </summary>
    /// <param name="x">First entry</param>
    /// <param name="y">Second entry</param>
    /// <returns>Negative, zero or positive when x is closer than, as close as or further than y</returns>
    public int Compare(DistanceEmbedding x, DistanceEmbedding y)
    {
        return x.distance.CompareTo(y.distance);
    }
}

/// <summary>
/// Behaviour which feeds webcam frames through a MediaPipe graph, collects the resulting
/// landmarks and runs a NatML model on them to predict the sign that is shown.
/// </summary>
public class SignPredictor : MonoBehaviour
{
    /// <summary>
    /// Predictor class which is used to predict the sign embedding using an MLEdgeModel.
    /// The raw model output is returned without any post-processing.
    /// </summary>
    public class NatMLSignPredictorEmbed : IMLPredictor<List<float>>
    {
        /// <summary>
        /// The MLEdgeModel used for predictions
        /// </summary>
        private readonly MLEdgeModel edgeModel;

        /// <summary>
        /// The type used to create features which are input for the model
        /// </summary>
        private readonly MLFeatureType featureType;

        /// <summary>
        /// Creation of a NatMLSignPredictorEmbed instance
        /// </summary>
        /// <param name="edgeModel">The model to predict with; its first input defines the feature type</param>
        public NatMLSignPredictorEmbed(MLEdgeModel edgeModel)
        {
            this.edgeModel = edgeModel;
            featureType = edgeModel.inputs[0];
        }

        /// <summary>
        /// Predicts the embedding using the MLEdgeModel
        /// </summary>
        /// <param name="inputs">Input features, only the first one is used</param>
        /// <returns>The flattened first model output, or null when the model produced no output</returns>
        public List<float> Predict(params MLFeature[] inputs)
        {
            List<float> predictions = null;
            IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
            MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
            try
            {
                var result = edgeModel.Predict(edgeFeature);
                try
                {
                    if (0 < result.Count)
                    {
                        predictions = new MLArrayFeature<float>(result[0]).Flatten().ToArray().ToList();
                    }
                }
                finally
                {
                    // Release the output features even when reading them failed
                    result.Dispose();
                }
            }
            finally
            {
                // Release the input feature even when the prediction itself failed
                edgeFeature.Dispose();
            }
            return predictions;
        }

        /// <summary>
        /// Disposing the MLEdgeModel
        /// </summary>
        public void Dispose()
        {
            edgeModel.Dispose();
        }
    }

    /// <summary>
/// Predictor class which is used to predict the sign using an MLEdgeModel.
    /// The model output is converted to probabilities with a softmax.
    /// </summary>
    public class NatMLSignPredictor : IMLPredictor<List<float>>
    {
        /// <summary>
        /// The MLEdgeModel used for predictions
        /// </summary>
        private readonly MLEdgeModel edgeModel;

        /// <summary>
        /// The type used to create features which are input for the model
        /// </summary>
        private readonly MLFeatureType featureType;

        /// <summary>
        /// Creation of a NatMLSignPredictor instance
        /// </summary>
        /// <param name="edgeModel">The model to predict with; its first input defines the feature type</param>
        public NatMLSignPredictor(MLEdgeModel edgeModel)
        {
            this.edgeModel = edgeModel;
            featureType = edgeModel.inputs[0];
        }

        /// <summary>
        /// Predicts the sign using the MLEdgeModel
        /// </summary>
        /// <param name="inputs">Input features, only the first one is used</param>
        /// <returns>
        /// One probability per model output (summing to 1), or null when the model produced no output
        /// </returns>
        public List<float> Predict(params MLFeature[] inputs)
        {
            List<float> predictions = null;
            IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
            MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
            try
            {
                var result = edgeModel.Predict(edgeFeature);
                try
                {
                    if (0 < result.Count)
                    {
                        float[] rawOutput = new MLArrayFeature<float>(result[0]).Flatten().ToArray();
                        predictions = Softmax(rawOutput);
                    }
                }
                finally
                {
                    // Release the output features even when reading them failed
                    result.Dispose();
                }
            }
            finally
            {
                // Release the input feature even when the prediction itself failed
                edgeFeature.Dispose();
            }
            return predictions;
        }

        /// <summary>
        /// Normalizes the values with a softmax: exp(v_i) / sum_j exp(v_j)
        /// </summary>
        /// <param name="values">Raw model output</param>
        /// <returns>A new list with the normalized values (empty when the input is empty)</returns>
        private static List<float> Softmax(float[] values)
        {
            List<float> normalized = new List<float>(values.Length);
            if (values.Length == 0)
            {
                return normalized;
            }

            // Subtracting the maximum leaves the result unchanged but keeps Mathf.Exp from
            // overflowing to Infinity (which would turn every probability into NaN)
            float max = values.Max();
            float sum = 0.0f;
            foreach (float value in values)
            {
                float exponent = Mathf.Exp(value - max);
                normalized.Add(exponent);
                sum += exponent;
            }
            return normalized.ConvertAll((e) => e / sum);
        }

        /// <summary>
        /// Disposing the MLEdgeModel
        /// </summary>
        public void Dispose()
        {
            edgeModel.Dispose();
        }
    }

    /// <summary>
    /// Listeners which are notified (via ProcessIncomingCall) every time new probabilities are available
    /// </summary>
    public List<Listener> listeners = new List<Listener>();

    /// <summary>
    /// Predictor which is used to create the asyncPredictor for the embedding model
    /// (should not be used if asyncPredictor exists)
    /// </summary>
    private NatMLSignPredictorEmbed predictor_embed;

    /// <summary>
    /// Predictor which is used to create the asyncPredictor for the fingerspelling model
    /// (should not be used if asyncPredictor exists)
    /// </summary>
    private NatMLSignPredictor predictor;

    /// <summary>
    /// The asynchronous predictor which is used to predict the sign using an MLEdgemodel
    /// </summary>
    private MLAsyncPredictor<List<float>> asyncPredictor;

    /// <summary>
    /// Reference to the model used in the SignPredictor
    /// </summary>
    private MLEdgeModel model;

    /// <summary>
    /// Modellist used to change model using ModelIndex
    /// </summary>
    public ModelList modelList;

    /// <summary>
    /// Chosen model data based on the operating system
    /// </summary>
    private MLModelData modelData;

    /// <summary>
    /// Reference to the model info file (passed to the KeypointManager of the fingerspelling model)
    /// </summary>
    public TextAsset modelInfoFile;

    /// <summary>
    /// Reference to the model info file of the embedding model (JSON array with the reference embeddings)
    /// </summary>
    public TextAsset modelInfoFileEmbedding;

    /// <summary>
    /// Config file to set up the
/// graph
    /// </summary>
    [SerializeField]
    private TextAsset configAsset;

    /// <summary>
    /// Index to indicate which camera is being used
    /// </summary>
    private int camdex = 0;

    /// <summary>
    /// The screen object on which the video is displayed
    /// </summary>
    [SerializeField]
    private RawImage screen;

    /// <summary>
    /// MediaPipe graph
    /// </summary>
    private CalculatorGraph graph;

    /// <summary>
    /// Resource manager for graph resources
    /// </summary>
    private ResourceManager resourceManager;

    /// <summary>
    /// Webcam texture
    /// </summary>
    private WebCamTexture webcamTexture = null;

    /// <summary>
    /// Input texture
    /// </summary>
    private Texture2D inputTexture;

    /// <summary>
    /// Screen pixel data
    /// </summary>
    private Color32[] pixelData;

    /// <summary>
    /// Stopwatch to give a timestamp to video frames
    /// </summary>
    private Stopwatch stopwatch;

    /// <summary>
    /// The mediapipe stream which contains the pose landmarks
    /// </summary>
    private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> posestream;

    /// <summary>
    /// The mediapipe stream which contains the left hand landmarks
    /// </summary>
    private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> leftstream;

    /// <summary>
    /// The mediapipe stream which contains the right hand landmarks
    /// </summary>
    private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> rightstream;

    /// <summary>
    /// Presence stream. It is never assigned or read in this file.
    /// NOTE(review): the type arguments were reconstructed, confirm against the original declaration.
    /// </summary>
    public OutputStream<DetectionVectorPacket, List<Detection>> presenceStream;

    /// <summary>
    /// A keypointmanager which does normalization stuff, keeps track of the landmarks
    /// </summary>
    private KeypointManager keypointManager;

    /// <summary>
    /// A keypointmanager which does normalization stuff, keeps track of the landmarks (for embedding model)
    /// </summary>
    private KeypointManagerEmbedding keypointManagerEmbedding;

    /// <summary>
    /// Width of the webcam
    /// </summary>
    private int width;

    /// <summary>
    /// Height of the webcam
    /// </summary>
    private int height;

    /// <summary>
    /// The prediction of the sign predictor model: probability per sign name
    /// </summary>
    public Dictionary<string, float> learnableProbabilities;

    /// <summary>
    /// Bool indicating whether or not the resource manager has already been initialized
    /// </summary>
    private static bool resourceManagerIsInitialized = false;

    /// <summary>
    /// Reference embeddings, parsed from modelInfoFileEmbedding when the embedding model is used
    /// </summary>
    private EmbeddingDataList embeddingDataList;

    /// <summary>
    /// Index of the model which was loaded in Start
    /// </summary>
    private ModelIndex modelID;

    /// <summary>
    /// Google Mediapipe setup and run
    /// </summary>
    /// <returns>IEnumerator</returns>
    /// <exception cref="System.InvalidOperationException">Thrown when no webcam is available</exception>
    private IEnumerator Start()
    {
        // Webcam setup
        WebCamDevice[] devices = WebCamTexture.devices;
        if (devices.Length == 0)
        {
            throw new System.InvalidOperationException("Web Camera devices are not found");
        }

        // Start the webcam and wait until the texture reports a usable size
        webcamTexture = new WebCamTexture(devices[0].name);
        webcamTexture.Play();
        yield return new WaitUntil(() => webcamTexture.width > 16);

        // Set webcam aspect ratio: keep the height of the screen and adapt its width
        width = webcamTexture.width;
        height = webcamTexture.height;
        float webcamAspect = (float)width / (float)height;
        float screenHeight = screen.rectTransform.sizeDelta.y;
        screen.rectTransform.sizeDelta = new Vector2(screenHeight * webcamAspect, screenHeight);
        // TODO: showing the WebCamTexture directly was marked as a stopgap by the original author
        screen.texture = webcamTexture;

        inputTexture = new Texture2D(width, height, TextureFormat.RGBA32, false);
        pixelData = new Color32[width * height];

        // The flag is static, so the assets are only prepared once per application run
        if (!resourceManagerIsInitialized)
        {
            resourceManager = new StreamingAssetsResourceManager();
            string[] graphAssets =
            {
                "pose_detection.bytes",
                "pose_landmark_full.bytes",
                "face_landmark.bytes",
                "hand_landmark_full.bytes",
                "face_detection_short_range.bytes",
                "hand_recrop.bytes",
                "handedness.txt"
            };
            foreach (string graphAsset in graphAssets)
            {
                yield return resourceManager.PrepareAssetAsync(graphAsset);
            }
            resourceManagerIsInitialized = true;
        }

        stopwatch = new Stopwatch();

        // Setting up the graph
        graph = new CalculatorGraph(configAsset.text);
        posestream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "pose_landmarks", "pose_landmarks_presence");
        leftstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "left_hand_landmarks", "left_hand_landmarks_presence");
        rightstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "right_hand_landmarks", "right_hand_landmarks_presence");

        posestream.StartPolling().AssertOk();
        leftstream.StartPolling().AssertOk();
        rightstream.StartPolling().AssertOk();

        graph.StartRun().AssertOk();
        stopwatch.Start();

        // Check if a model is ready to load
        yield return new WaitUntil(() => modelList.HasValidModel());

        // Create Model on a worker thread so the main thread keeps rendering
        Task<MLEdgeModel> t = Task.Run(() => MLEdgeModel.Create(modelList.GetCurrentModel()));
        yield return new WaitUntil(() => t.IsCompleted);

        model = t.Result;
        modelID = modelList.GetCurrentModelIndex();

        if (modelID == ModelIndex.FINGERSPELLING)
        {
            predictor = new NatMLSignPredictor(model);
            asyncPredictor = predictor.ToAsync();

            // Creating a KeypointManager
            keypointManager = new KeypointManager(modelInfoFile);

            StartCoroutine(SignRecognitionCoroutine());
            StartCoroutine(MediapipeCoroutine());
        }
        else
        {
            predictor_embed = new NatMLSignPredictorEmbed(model);
            asyncPredictor = predictor_embed.ToAsync();

            // Creating a KeypointManager
            keypointManagerEmbedding = new KeypointManagerEmbedding();

            // Read the embedding data. The file holds a JSON array, which JsonUtility can only
            // parse when it is wrapped in an object.
            embeddingDataList = JsonUtility.FromJson<EmbeddingDataList>($"{{\"dataList\":{modelInfoFileEmbedding.text}}}");

            // Start the Coroutine
            StartCoroutine(SignRecognitionCoroutineEmbed());
            StartCoroutine(MediapipeCoroutineEmbed());
        }
    }

    /// <summary>
    /// Coroutine which executes the mediapipe pipeline and hands the landmarks to the keypointManager
    /// </summary>
    /// <returns>IEnumerator</returns>
    private IEnumerator MediapipeCoroutine()
    {
        while (true)
        {
            // Copy the current webcam frame into the input texture and wrap it in an image frame
            inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
            var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData<byte>());

            // Timestamp in microseconds. Elapsed.Ticks is always expressed in 100 ns units, whereas
            // ElapsedTicks depends on Stopwatch.Frequency and is therefore platform dependent.
            long currentTimestamp = stopwatch.Elapsed.Ticks / (System.TimeSpan.TicksPerMillisecond / 1000);
            graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();

            yield return new WaitForEndOfFrame();

            NormalizedLandmarkList _poseLandmarks = null;
            NormalizedLandmarkList _leftHandLandmarks = null;
            NormalizedLandmarkList _rightHandLandmarks = null;

            // The predicates always return true: each stream is polled exactly once and a
            // stream without a new packet leaves its landmark list null
            yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks); return true; });
            yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks); return true; });
            yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks); return true; });

            keypointManager.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
        }
    }
/// <summary>
    /// Coroutine which executes the mediapipe pipeline and hands the landmarks to the keypointManagerEmbedding
    /// </summary>
    /// <returns>IEnumerator</returns>
    private IEnumerator MediapipeCoroutineEmbed()
    {
        while (true)
        {
            // Copy the current webcam frame into the input texture and wrap it in an image frame
            inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
            var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData<byte>());

            // Timestamp in microseconds. Elapsed.Ticks is always expressed in 100 ns units, whereas
            // ElapsedTicks depends on Stopwatch.Frequency and is therefore platform dependent.
            long currentTimestamp = stopwatch.Elapsed.Ticks / (System.TimeSpan.TicksPerMillisecond / 1000);
            graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();

            yield return new WaitForEndOfFrame();

            NormalizedLandmarkList _poseLandmarks = null;
            NormalizedLandmarkList _leftHandLandmarks = null;
            NormalizedLandmarkList _rightHandLandmarks = null;

            // The predicates always return true: each stream is polled exactly once and a
            // stream without a new packet leaves its landmark list null
            yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks); return true; });
            yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks); return true; });
            yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks); return true; });

            keypointManagerEmbedding.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
        }
    }

    /// <summary>
    /// Minkowski distance of order p between two vectors: (sum_i |x_i - y_i|^p)^(1/p)
    /// </summary>
    /// <param name="x">First vector; its length determines the number of dimensions</param>
    /// <param name="y">Second vector, has to contain at least as many values as x</param>
    /// <param name="p">Order of the distance (2 gives the Euclidean distance)</param>
    /// <returns>The distance between x and y</returns>
    private float MinkowskiDistance(List<float> x, float[] y, int p)
    {
        float accumulated = 0;
        for (int d = 0; d < x.Count; d++)
        {
            accumulated += Mathf.Pow(Mathf.Abs(x[d] - y[d]), p);
        }
        return Mathf.Pow(accumulated, 1.0f / p);
    }

    /// <summary>
    /// Computes the distance between the given embedding and every reference embedding
    /// </summary>
    /// <param name="embedding">The embedding predicted by the model</param>
    /// <param name="p">Order of the Minkowski distance</param>
    /// <returns>One entry per reference embedding, sorted by ascending distance</returns>
    private List<DistanceEmbedding> GetDistances(List<float> embedding, int p = 2)
    {
        List<EmbeddingData> references = embeddingDataList.dataList;
        List<DistanceEmbedding> distances = new List<DistanceEmbedding>(references.Count);
        foreach (EmbeddingData data in references)
        {
            float distance = MinkowskiDistance(embedding, data.embeddings, p);
            distances.Add(new DistanceEmbedding(distance, data));
        }

        // Sorting once is O(n log n); inserting every item at its sorted position is O(n^2)
        distances.Sort(new DistanceComparer());
        return distances;
    }

    /// <summary>
    /// Names of the outputs of the fingerspelling model, in output order.
    /// Temporary fix: hard-coded here rather than read from the model info.
    /// </summary>
    private static readonly string[] fingerspellingSigns =
    {
        "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
        "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"
    };

    /// <summary>
    /// Coroutine which calls the sign predictor model
    /// </summary>
    /// <returns>IEnumerator</returns>
    private IEnumerator SignRecognitionCoroutine()
    {
        while (true)
        {
            List<List<float>> inputData = keypointManager.GetKeypoints();
            if (inputData != null && inputData.Count > 0 && asyncPredictor.readyForPrediction)
            {
                // Getting the size of the input data
                int framecount = inputData.Count;
                int keypointsPerFrame = inputData[0].Count;

                // Creating ArrayFeature: the frames are flattened row by row
                int[] shape = { framecount, keypointsPerFrame };
                float[] input = new float[framecount * keypointsPerFrame];
                int i = 0;
                foreach (List<float> frame in inputData)
                {
                    foreach (float value in frame)
                    {
                        input[i++] = value;
                    }
                }

                MLArrayFeature<float> feature = new MLArrayFeature<float>(input, shape);

                // Predicting
                Task<List<float>> task = Task.Run(async () => await asyncPredictor.Predict(feature));
                yield return new WaitUntil(() => task.IsCompleted);

                // The predictor returns null when the model produced no output
                List<float> result = task.Result;
                if (result != null && 0 < result.Count)
                {
                    // Never index past the known sign names when the model has more outputs
                    int signCount = Mathf.Min(result.Count, fingerspellingSigns.Length);
                    Dictionary<string, float> probabilities = new Dictionary<string, float>(signCount);
                    for (int j = 0; j < signCount; j++)
                    {
                        probabilities.Add(fingerspellingSigns[j], result[j]);
                    }
                    learnableProbabilities = probabilities;

                    foreach (Listener listener in listeners)
                    {
                        yield return listener.ProcessIncomingCall();
                    }
                }
                else
                {
                    // Wait until next frame
                    yield return null;
                }
            }

            yield return null;
        }
    }

    /// <summary>
    /// Distance at which the distance-to-score sigmoid yields 0.5
    /// </summary>
    private const float embeddingDistanceMidpoint = 1.85f;

    /// <summary>
    /// Steepness of the distance-to-score sigmoid
    /// </summary>
    private const float embeddingDistanceSteepness = 2.0f;

    /// <summary>
    /// Coroutine which calls the sign predictor embedding model
    /// </summary>
    /// <returns>IEnumerator</returns>
    private IEnumerator SignRecognitionCoroutineEmbed()
    {
        while (true)
        {
            List<List<List<float>>> inputData = keypointManagerEmbedding.GetKeypoints();
            if (inputData != null && inputData.Count > 0 && asyncPredictor.readyForPrediction)
            {
                // Getting the size of the input data
                int framecount = inputData.Count;
                int keypointsPerFrame = inputData[0].Count;

                // Creating ArrayFeature: every keypoint contributes 2 values and a batch dimension of 1 is added
                int[] shape = { 1, framecount, keypointsPerFrame, 2 };
                float[] input = new float[framecount * keypointsPerFrame * 2];
                int i = 0;
                foreach (List<List<float>> frame in inputData)
                {
                    foreach (List<float> keypoint in frame)
                    {
                        foreach (float coordinate in keypoint)
                        {
                            input[i++] = coordinate;
                        }
                    }
                }

                MLArrayFeature<float> feature = new MLArrayFeature<float>(input, shape);

                // Predicting
                Task<List<float>> task = Task.Run(async () => await asyncPredictor.Predict(feature));
                yield return new WaitUntil(() => task.IsCompleted);

                // The predictor returns null when the model produced no output
                List<float> result = task.Result;
                if (result != null && 0 < result.Count)
                {
                    List<DistanceEmbedding> distances = GetDistances(result, 2);

                    // Keep the smallest distance per label
                    Dictionary<string, float> closest = new Dictionary<string, float>();
                    foreach (DistanceEmbedding candidate in distances)
                    {
                        string label = candidate.embeddingData.label_name;
                        float best;
                        if (!closest.TryGetValue(label, out best) || candidate.distance < best)
                        {
                            closest[label] = candidate.distance;
                        }
                    }

                    // Convert distances to scores with a decreasing sigmoid: the smaller the
                    // distance the higher the score, a distance of 1.85 gives exactly 0.5
                    Dictionary<string, float> scores = new Dictionary<string, float>(closest.Count);
                    float sum = 0.0f;
                    foreach (KeyValuePair<string, float> entry in closest)
                    {
                        float score = 1 / (1 + Mathf.Exp(embeddingDistanceSteepness * (entry.Value - embeddingDistanceMidpoint)));
                        scores.Add(entry.Key, score);
                        sum += score;
                    }

                    // Normalize so the scores sum to 1. When every score underflowed to 0 the
                    // division would produce NaN, in that case all probabilities stay 0.
                    Dictionary<string, float> probabilities = new Dictionary<string, float>(scores.Count);
                    foreach (KeyValuePair<string, float> entry in scores)
                    {
                        probabilities.Add(entry.Key, sum > 0.0f ? entry.Value / sum : 0.0f);
                    }
                    learnableProbabilities = probabilities;

                    foreach (Listener listener in listeners)
                    {
                        yield return listener.ProcessIncomingCall();
                    }
                }
            }

            yield return null;
        }
    }

    /// <summary>
    /// Proper destruction of the Mediapipe graph, the webcam and the predictor
    /// </summary>
    private void OnDestroy()
    {
        if (webcamTexture != null)
        {
            webcamTexture.Stop();
        }

        try
        {
            if (graph != null)
            {
                try
                {
                    graph.CloseInputStream("input_video").AssertOk();
                    graph.WaitUntilDone().AssertOk();
                }
                finally
                {
                    graph.Dispose();
                    graph = null;
                }
            }
        }
        finally
        {
            // Also release the predictor when shutting down the graph failed
            if (asyncPredictor != null)
            {
                asyncPredictor.Dispose();
                asyncPredictor = null;
            }
        }
    }

    /// <summary>
    /// So long as there are cameras to use, you swap the camera you are using to another in the list.
/// </summary>
    public void SwapCam()
    {
        WebCamDevice[] devices = WebCamTexture.devices;
        if (devices.Length == 0)
        {
            return;
        }

        // Stop the old camera
        // If there was no camera playing before, then you dont have to reset the texture, as it wasn't assigned in the first place.
        // webcamTexture stays null until Start has created it, so it has to be checked before it is used.
        if (webcamTexture != null && webcamTexture.isPlaying)
        {
            screen.texture = null;
            webcamTexture.Stop();
            webcamTexture = null;
        }

        // Find the new camera
        camdex = (camdex + 1) % devices.Length;

        // Start the new camera
        // NOTE(review): width, height, pixelData and inputTexture are not updated here, so a camera
        // with a different resolution than the first one may not fit the existing buffers - confirm
        webcamTexture = new WebCamTexture(devices[camdex].name);
        screen.texture = webcamTexture;
        webcamTexture.Play();
    }

    /// <summary>
    /// Selects the model with the given index in the model list
    /// </summary>
    /// <param name="index">Index of the model to make current</param>
    public void SetModel(ModelIndex index)
    {
        this.modelList.SetCurrentModel(index);
    }

    /// <summary>
    /// Swaps the display screens
    /// </summary>
    /// <param name="screen">The screen on which the webcam should be displayed from now on</param>
    public void SwapScreen(RawImage screen)
    {
        this.screen = screen;

        // Without a webcam texture there is nothing to show yet
        if (webcamTexture == null)
        {
            return;
        }

        // Keep the height of the new screen and adapt its width to the aspect ratio of the webcam
        float webcamAspect = (float)webcamTexture.width / (float)webcamTexture.height;
        float screenHeight = this.screen.rectTransform.sizeDelta.y;
        this.screen.rectTransform.sizeDelta = new Vector2(screenHeight * webcamAspect, screenHeight);
        this.screen.texture = webcamTexture;
    }
}