using Mediapipe;
using Mediapipe.Unity;
using NatML;
using NatML.Features;
using NatML.Internal;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using UnityEngine;
using UnityEngine.UI;
// NOTE(review): in this copy of the file the generic type arguments appear to have been
// stripped (e.g. `IMLPredictor>`, `List Predict`, `new List()`), so the code does not
// compile as shown; restore them (List<float>, List<Listener>, Dictionary<string, float>, ...)
// before building.
/// <summary>
/// Unity component that runs the sign-recognition pipeline: it captures webcam frames,
/// sends them through a MediaPipe CalculatorGraph to obtain pose and left/right hand
/// landmarks, hands those landmarks to a KeypointManager, runs the resulting keypoint
/// sequence through a NatML edge model and publishes the per-sign probabilities in
/// learnableProbabilities before notifying the registered listeners.
/// </summary>
public class SignPredictor : MonoBehaviour
{
/// <summary>
/// Synchronous NatML predictor that runs the edge model on one keypoint feature and turns
/// the model's first output into a probability distribution with a softmax.
/// </summary>
public class NatMLSignPredictor : IMLPredictor>
{
/// <summary>
/// The edge model used for predictions. Owned by this predictor and disposed in Dispose().
/// </summary>
private readonly MLEdgeModel edgeModel;
/// <summary>
/// Type of the model's first input (edgeModel.inputs[0]); used to convert the incoming
/// feature into an MLEdgeFeature the model accepts.
/// </summary>
private MLFeatureType featureType;
/// <summary>
/// Creates a predictor around an already-loaded edge model.
/// </summary>
/// <param name="edgeModel">
/// The model to run. The predictor takes ownership: Dispose() disposes it.
/// </param>
public NatMLSignPredictor(MLEdgeModel edgeModel)
{
this.edgeModel = edgeModel;
featureType = edgeModel.inputs[0];
}
/// <summary>
/// Runs the edge model on inputs[0] and returns the softmax of its flattened first output.
/// </summary>
/// <param name="inputs">
/// Model inputs. Only inputs[0] is used, and it is cast directly to IMLEdgeFeature,
/// so any other feature type throws InvalidCastException.
/// </param>
/// <returns>
/// One probability per element of the model's first output (normalized to sum to 1),
/// or null when the model produced no outputs.
/// </returns>
public List Predict(params MLFeature[] inputs)
{
List predictions = null;
IMLEdgeFeature iedgeFeature = (IMLEdgeFeature)inputs[0];
MLEdgeFeature edgeFeature = iedgeFeature.Create(featureType);
MLFeatureCollection result = edgeModel.Predict(edgeFeature);
if (0 < result.Count)
{
// Softmax over the raw model outputs: exponentiate each value, then divide by the sum.
// NOTE(review): the maximum is not subtracted before Mathf.Exp, so large outputs
// overflow to +Infinity and the division then produces NaN; subtracting
// predictions.Max() first gives a numerically stable softmax.
predictions = new MLArrayFeature(result[0]).Flatten().ToArray().ToList();
predictions = predictions.ConvertAll((c) => Mathf.Exp(c));
float sum = predictions.Sum();
predictions = predictions.ConvertAll((c) => c / sum);
}
// NOTE(review): these two calls are skipped if anything above throws; a try/finally
// would guarantee the edge feature and the output collection are always released.
edgeFeature.Dispose();
result.Dispose();
return predictions;
}
/// <summary>
/// Releases the edge model owned by this predictor.
/// </summary>
public void Dispose()
{
edgeModel.Dispose();
}
}
/// <summary>
/// Listeners processed after each new set of probabilities is stored in learnableProbabilities.
/// The result of each ProcessIncomingCall() is yielded from the recognition coroutine, so the
/// listeners are handled one after another.
/// </summary>
public List listeners = new List();
/// <summary>
/// Synchronous predictor from which asyncPredictor is created. Once asyncPredictor exists,
/// use asyncPredictor instead of calling this one directly.
/// </summary>
private NatMLSignPredictor predictor;
/// <summary>
/// Asynchronous wrapper around predictor (created with ToAsync()). SignRecognitionCoroutine
/// uses it for predictions, and it is disposed in OnDestroy().
/// </summary>
private MLAsyncPredictor> asyncPredictor;
/// <summary>
/// Edge model created in Start() from modelList.GetCurrentModel().
/// </summary>
private MLEdgeModel model;
/// <summary>
/// Available models. Start() waits until HasValidModel() is true and loads the current model;
/// SetModel() changes the current model by ModelIndex.
/// </summary>
public ModelList modelList;
// NOTE(review): this field is never read or assigned in this class.
/// <summary>
/// Intended to hold the model data chosen for the current operating system.
/// </summary>
private MLModelData modelData;
/// <summary>
/// Model info file passed to the KeypointManager constructor in Start().
/// </summary>
public TextAsset modelInfoFile;
/// <summary>
/// MediaPipe graph configuration; its text is passed to the CalculatorGraph constructor in Start().
/// </summary>
[SerializeField]
private TextAsset configAsset;
/// <summary>
/// Index into WebCamTexture.devices of the camera selected by SwapCam().
/// Start() always opens device 0, which matches the initial value.
/// </summary>
private int camdex = 0;
/// <summary>
/// The RawImage on which the webcam video is displayed.
/// </summary>
[SerializeField]
private RawImage screen;
/// <summary>
/// MediaPipe graph that turns video frames into pose and hand landmarks.
/// </summary>
private CalculatorGraph graph;
// NOTE(review): only assigned while the static resourceManagerIsInitialized flag is still false,
// so on every later instance of this component this field stays null.
/// <summary>
/// Resource manager that prepares the MediaPipe asset files (pose/face/hand models, handedness.txt)
/// from StreamingAssets, presumably so that the graph can load them.
/// </summary>
private ResourceManager resourceManager;
/// <summary>
/// Texture receiving frames from the active webcam; it is also shown on screen.
/// </summary>
private WebCamTexture webcamTexture;
/// <summary>
/// CPU-side RGBA32 copy of the current webcam frame. Its raw bytes are sent to MediaPipe.
/// </summary>
private Texture2D inputTexture;
/// <summary>
/// Reusable buffer of width * height pixels that GetPixels32 fills with the current webcam frame.
/// </summary>
private Color32[] pixelData;
/// <summary>
/// Started right after the graph starts running. Its elapsed time is used as the timestamp of each video frame.
/// </summary>
private Stopwatch stopwatch;
/// <summary>
/// Output stream for the graph's "pose_landmarks" stream (presence stream "pose_landmarks_presence").
/// </summary>
private OutputStream posestream;
/// <summary>
/// Output stream for the graph's "left_hand_landmarks" stream (presence stream "left_hand_landmarks_presence").
/// </summary>
private OutputStream leftstream;
/// <summary>
/// Output stream for the graph's "right_hand_landmarks" stream (presence stream "right_hand_landmarks_presence").
/// </summary>
private OutputStream rightstream;
// NOTE(review): never assigned or read in this class; it stays null unless set from outside.
/// <summary>
/// Presence stream.
/// </summary>
public OutputStream> presenceStream;
/// <summary>
/// Collects the landmarks from the three streams and provides the keypoint sequence that is fed
/// to the model. Normalization presumably happens inside KeypointManager.
/// </summary>
private KeypointManager keypointManager;
/// <summary>
/// Width of the webcam frames, captured in Start().
/// </summary>
private int width;
/// <summary>
/// Height of the webcam frames, captured in Start().
/// </summary>
private int height;
/// <summary>
/// Latest prediction of the sign predictor model: one probability per sign label ("A" to "Z").
/// Rebuilt after every prediction that has at least one output; null until the first such prediction.
/// </summary>
public Dictionary learnableProbabilities;
/// <summary>
/// Whether the MediaPipe assets have already been prepared. The field is static, so the preparation
/// in Start() runs only once for all instances of this component.
/// </summary>
private static bool resourceManagerIsInitialized = false;
/// <summary>
/// Sets up and runs Google MediaPipe. It opens the first webcam, sizes the screen to the webcam's
/// aspect ratio, prepares the MediaPipe assets (once), and starts the graph and its landmark streams.
/// It then loads the edge model on a worker thread and finally starts the MediaPipe and
/// sign-recognition coroutines.
/// </summary>
/// <returns>IEnumerator run by Unity as the Start coroutine.</returns>
/// <exception cref="System.Exception">Thrown when no webcam device is found.</exception>
private IEnumerator Start()
{
// Webcam setup
if (WebCamTexture.devices.Length == 0)
{
throw new System.Exception("Web Camera devices are not found");
}
// Start the webcam
WebCamDevice webCamDevice = WebCamTexture.devices[0];
webcamTexture = new WebCamTexture(webCamDevice.name);
webcamTexture.Play();
// Wait until the texture reports a real frame size. Before the camera delivers frames, the
// reported size is presumably a small placeholder.
yield return new WaitUntil(() => webcamTexture.width > 16);
// Set webcam aspect ratio: keep the screen's height and derive its width from the camera.
width = webcamTexture.width;
height = webcamTexture.height;
float webcamAspect = (float)webcamTexture.width / (float)webcamTexture.height;
screen.rectTransform.sizeDelta = new Vector2(screen.rectTransform.sizeDelta.y * webcamAspect, (screen.rectTransform.sizeDelta.y));
screen.texture = webcamTexture;
// TODO: the original note here was left unfinished ("this method is kinda meh you should use").
// Every frame is copied on the CPU (GetPixels32 -> SetPixels32 -> GetRawTextureData); consider a faster path.
inputTexture = new Texture2D(width, height, TextureFormat.RGBA32, false);
pixelData = new Color32[width * height];
// Prepare the MediaPipe asset files only once for all instances (static flag).
if (!resourceManagerIsInitialized)
{
resourceManager = new StreamingAssetsResourceManager();
yield return resourceManager.PrepareAssetAsync("pose_detection.bytes");
yield return resourceManager.PrepareAssetAsync("pose_landmark_full.bytes");
yield return resourceManager.PrepareAssetAsync("face_landmark.bytes");
yield return resourceManager.PrepareAssetAsync("hand_landmark_full.bytes");
yield return resourceManager.PrepareAssetAsync("face_detection_short_range.bytes");
yield return resourceManager.PrepareAssetAsync("hand_recrop.bytes");
yield return resourceManager.PrepareAssetAsync("handedness.txt");
resourceManagerIsInitialized = true;
}
stopwatch = new Stopwatch();
// Setting up the graph: create the three landmark output streams and start polling them
// before the graph starts running.
graph = new CalculatorGraph(configAsset.text);
posestream = new OutputStream(graph, "pose_landmarks", "pose_landmarks_presence");
leftstream = new OutputStream(graph, "left_hand_landmarks", "left_hand_landmarks_presence");
rightstream = new OutputStream(graph, "right_hand_landmarks", "right_hand_landmarks_presence");
posestream.StartPolling().AssertOk();
leftstream.StartPolling().AssertOk();
rightstream.StartPolling().AssertOk();
graph.StartRun().AssertOk();
stopwatch.Start();
// Creating a KeypointManager
keypointManager = new KeypointManager(modelInfoFile);
// Check if a model is ready to load
yield return new WaitUntil(() => modelList.HasValidModel());
// Create the model on a worker thread and poll for completion so the main thread is not blocked.
// NOTE(review): t.Result rethrows (wrapped in an AggregateException) if model creation failed.
Task t = Task.Run(() => MLEdgeModel.Create(modelList.GetCurrentModel()));
yield return new WaitUntil(() => t.IsCompleted);
model = t.Result;
predictor = new NatMLSignPredictor(model);
asyncPredictor = predictor.ToAsync();
// Start the coroutines
StartCoroutine(SignRecognitionCoroutine());
StartCoroutine(MediapipeCoroutine());
}
/// <summary>
/// Coroutine that executes the MediaPipe pipeline. On every iteration it sends the current webcam
/// frame to the graph's "input_video" stream, then reads one result from each landmark stream
/// and passes the three results to the KeypointManager.
/// </summary>
/// <returns>IEnumerator with an endless loop; it never finishes on its own.</returns>
private IEnumerator MediapipeCoroutine()
{
while (true)
{
// Copy the current webcam frame into inputTexture and wrap its raw RGBA bytes in an ImageFrame.
// NOTE(review): pixelData, width and height are sized for the camera opened in Start(). After
// SwapCam() switches to a camera with a different resolution, they no longer match the frame.
inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData());
// Intended as a timestamp in microseconds (ElapsedTicks / 10).
// NOTE(review): Stopwatch.ElapsedTicks counts Stopwatch.Frequency units, not TimeSpan ticks.
// This is only microseconds when Stopwatch.Frequency == 10,000,000; stopwatch.Elapsed.Ticks / 10
// is independent of the timer resolution.
var currentTimestamp = stopwatch.ElapsedTicks / (System.TimeSpan.TicksPerMillisecond / 1000);
graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();
yield return new WaitForEndOfFrame();
NormalizedLandmarkList _poseLandmarks = null;
NormalizedLandmarkList _leftHandLandmarks = null;
NormalizedLandmarkList _rightHandLandmarks = null;
// Each predicate polls its stream once and then returns true, so these WaitUntil yields do not
// wait for a packet. TryGetNext's return value is ignored, so a landmark list may be null here.
yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks); return true; });
yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks); return true; });
yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks); return true; });
keypointManager.AddLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
}
}
/// <summary>
/// Coroutine that calls the sign predictor model. Whenever the KeypointManager returns a keypoint
/// sequence and the async predictor is ready, it runs a prediction and maps the outputs to the
/// letters A-Z in learnableProbabilities. It then processes every listener.
/// </summary>
/// <returns>IEnumerator with an endless loop; it never finishes on its own.</returns>
private IEnumerator SignRecognitionCoroutine()
{
while (true)
{
List> inputData = keypointManager.GetKeypoints();
if (inputData != null && asyncPredictor.readyForPrediction)
{
// Getting the size of the input data.
// NOTE(review): this assumes inputData is non-empty and that every frame has inputData[0].Count
// keypoints (TODO: confirm KeypointManager guarantees this). Otherwise inputData[0] or input[i++] throws.
int framecount = inputData.Count;
int keypointsPerFrame = inputData[0].Count;
// Creating ArrayFeature: flatten the frames row by row into a [framecount, keypointsPerFrame] array.
int[] shape = { framecount, keypointsPerFrame };
float[] input = new float[framecount * keypointsPerFrame];
int i = 0;
inputData.ForEach((e) => e.ForEach((f) => input[i++] = f));
MLArrayFeature feature = new MLArrayFeature(input, shape);
// Predicting asynchronously; poll until the task completes so the main thread is not blocked.
// NOTE(review): task.Result rethrows if the prediction failed. Also, NatMLSignPredictor.Predict
// returns null when the model produced no outputs, and result.Count below would then throw.
Task> task = Task.Run(async () => await asyncPredictor.Predict(feature));
yield return new WaitUntil(() => task.IsCompleted);
List result = task.Result;
if (0 < result.Count)
{
learnableProbabilities = new Dictionary();
// Temporary fix: hard-coded labels; output j is mapped to the j-th letter of the alphabet.
// NOTE(review): a model with more than 26 outputs makes signs[j] throw ArgumentOutOfRangeException.
List signs = new List()
{
"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
"N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"
};
for (int j = 0; j < result.Count; j++)
{
learnableProbabilities.Add(signs[j].ToUpper(), result[j]);
}
//Debug.Log($"prob = [{learnableProbabilities.Aggregate(" ", (t, kv) => $"{t}{kv.Key}:{kv.Value} ")}]");
// NOTE(review): the foreach is suspended at each yield. If listeners is modified meanwhile
// (e.g. a listener removes itself), the next iteration throws InvalidOperationException;
// iterate over a copy if that can happen.
foreach (Listener listener in listeners)
{
yield return listener.ProcessIncomingCall();
}
}
else
{
// Wait until next frame
yield return null;
}
}
yield return null;
}
}
/// <summary>
/// Proper destruction of the MediaPipe graph and the other resources. It stops the webcam, closes the
/// graph's input stream and waits for the graph to finish (the graph is disposed even if that fails),
/// then disposes the async predictor.
/// </summary>
private void OnDestroy()
{
if (webcamTexture != null)
{
webcamTexture.Stop();
}
if (graph != null)
{
// NOTE(review): if AssertOk() throws here, the finally block still disposes the graph, but the
// exception leaves OnDestroy, so asyncPredictor below is never disposed.
try
{
graph.CloseInputStream("input_video").AssertOk();
graph.WaitUntilDone().AssertOk();
}
finally
{
graph.Dispose();
}
}
// Disposing the async predictor presumably also disposes the wrapped predictor and its model.
// TODO: confirm against the NatML documentation.
if (asyncPredictor != null)
{
asyncPredictor.Dispose();
}
// NOTE(review): inputTexture (a Texture2D created in Start) is never destroyed here.
}
// NOTE(review): throws NullReferenceException if called before Start() has created webcamTexture.
// width/height/inputTexture/pixelData are also not resized for the newly selected camera.
/// <summary>
/// If any cameras are available, stops the current camera and switches to the next one in
/// WebCamTexture.devices (wrapping around to the first), showing it on screen.
/// </summary>
public void SwapCam()
{
if (WebCamTexture.devices.Length > 0)
{
// Stop the old camera.
// Only a playing camera has to be stopped and detached from the screen; a camera that is not
// playing is simply replaced below.
if (webcamTexture.isPlaying)
{
screen.texture = null;
webcamTexture.Stop();
webcamTexture = null;
}
// Find the new camera (cyclic index into the device list)
camdex += 1;
camdex %= WebCamTexture.devices.Length;
// Start the new camera
WebCamDevice device = WebCamTexture.devices[camdex];
webcamTexture = new WebCamTexture(device.name);
screen.texture = webcamTexture;
webcamTexture.Play();
}
}
// NOTE(review): the model is loaded only once in Start(), and nothing in this class reloads `model`
// afterwards, so calling this later does not change the model used for predictions here.
/// <summary>
/// Selects the model with the given index in modelList.
/// </summary>
/// <param name="index">Index of the model to make current.</param>
public void SetModel(ModelIndex index)
{
this.modelList.SetCurrentModel(index);
}
// NOTE(review): unlike Start(), the new screen is not resized to the webcam aspect ratio.
/// <summary>
/// Swaps the display screen: the webcam texture is shown on the given RawImage from now on.
/// </summary>
/// <param name="screen">The RawImage that should display the webcam video.</param>
public void SwapScreen(RawImage screen)
{
this.screen = screen;
this.screen.texture = webcamTexture;
}
}