Files
unity-application/Assets/MediaPipeUnity/Scripts/SignPredictor.cs
2023-04-02 12:27:59 +00:00

416 lines
15 KiB
C#

// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
// ATTENTION!: This code is for a tutorial.
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Unity.Barracuda;
using UnityEngine;
using UnityEngine.UI;
namespace Mediapipe.Unity.Tutorial
{
public class SignPredictor : MonoBehaviour
{
/// <summary>
/// List of selectable models; the current entry is chosen with SetCurrentModel (see SetModel)
/// and read back with GetCurrentModel when the worker is created
/// </summary>
public ModelList modelList;
/// <summary>
/// Reference to the model info file; it is handed to the KeypointManager constructor in Start
/// </summary>
public TextAsset modelInfoFile;
/// <summary>
/// Config file whose text is used to set up the MediaPipe graph
/// </summary>
[SerializeField]
private TextAsset configAsset;
/// <summary>
/// Index into WebCamTexture.devices indicating which camera is being used
/// </summary>
private int camdex = 0;
/// <summary>
/// The screen object on which the video is displayed
/// </summary>
[SerializeField]
private RawImage screen;
/// <summary>
/// A secondary, optional screen object on which the video can be displayed (may be null)
/// </summary>
[SerializeField]
private RawImage screen2;
/// <summary>
/// MediaPipe graph, built from configAsset in Start and disposed in OnDestroy
/// </summary>
private CalculatorGraph graph;
/// <summary>
/// Resource manager used to prepare the model asset files the graph needs
/// </summary>
private ResourceManager resourceManager;
/// <summary>
/// Webcam texture of the camera currently in use
/// </summary>
private WebCamTexture webcamTexture;
/// <summary>
/// Input texture: receives a copy of the webcam pixels, whose raw data is wrapped in the
/// ImageFrame sent to the graph
/// </summary>
private Texture2D inputTexture;
/// <summary>
/// Reusable buffer for the webcam pixel data (width * height entries)
/// </summary>
private Color32[] pixelData;
/// <summary>
/// Stopwatch used to give a timestamp to video frames
/// </summary>
private Stopwatch stopwatch;
/// <summary>
/// The MediaPipe stream which contains the pose landmarks (graph output "pose_landmarks")
/// </summary>
private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> posestream;
/// <summary>
/// The MediaPipe stream which contains the left hand landmarks (graph output "left_hand_landmarks")
/// </summary>
private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> leftstream;
/// <summary>
/// The MediaPipe stream which contains the right hand landmarks (graph output "right_hand_landmarks")
/// </summary>
private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> rightstream;
/// <summary>
/// Presence stream carrying detections.
/// NOTE(review): nothing in the code visible here assigns or reads this field; confirm whether
/// external code relies on it.
/// </summary>
public OutputStream<DetectionVectorPacket, List<Detection>> presenceStream;
/// <summary>
/// Keypoint manager which receives the landmarks of every frame and returns the keypoints
/// that are fed to the sign predictor model
/// </summary>
private KeypointManager keypointManager;
/// <summary>
/// The worker on which we schedule the sign predictor model execution
/// </summary>
private IWorker worker;
/// <summary>
/// Width of the webcam texture
/// </summary>
private int width;
/// <summary>
/// Height of the webcam texture
/// </summary>
private int height;
/// <summary>
/// The enumerator of the worker which executes the sign predictor model step by step
/// </summary>
private IEnumerator enumerator;
/// <summary>
/// The latest prediction of the sign predictor model: softmax probability per label,
/// where the labels are the letters "A", "B", ... in output order
/// </summary>
public Dictionary<string, float> learnableProbabilities;
/// <summary>
/// Bool indicating whether or not the resource manager has already been initialized.
/// Static, so the value is shared by every SignPredictor instance.
/// </summary>
private static bool resourceManagerIsInitialized = false;
/// <summary>
/// The input tensor for the sign predictor; recreated for every prediction
/// </summary>
private Tensor inputTensor;
/// <summary>
/// Listeners whose ProcessIncomingCall is run after every new prediction
/// </summary>
public List<Listener> listeners = new List<Listener>();
/// <summary>
/// Unity start-up coroutine: opens the first webcam, fits the screen(s) to its aspect ratio and,
/// when a model is selected, prepares the MediaPipe graph and launches the two processing coroutines.
/// </summary>
/// <returns>IEnumerator driven by Unity's coroutine scheduler</returns>
/// <exception cref="System.Exception">Thrown when no webcam device is available</exception>
private IEnumerator Start()
{
    // Webcam setup: without a camera there is nothing to do
    WebCamDevice[] availableCameras = WebCamTexture.devices;
    if (availableCameras.Length == 0)
    {
        throw new System.Exception("Web Camera devices are not found");
    }

    // Start the first webcam and wait until it reports a width above 16
    // (presumably a placeholder size is reported until real frames arrive - TODO confirm)
    webcamTexture = new WebCamTexture(availableCameras[0].name);
    webcamTexture.Play();
    yield return new WaitUntil(() => webcamTexture.width > 16);

    // Remember the camera resolution and derive its aspect ratio
    width = webcamTexture.width;
    height = webcamTexture.height;
    float aspectRatio = (float)webcamTexture.width / (float)webcamTexture.height;

    // Keep the configured height of the primary screen and adapt its width to the camera
    RectTransform primaryRect = screen.rectTransform;
    float primaryHeight = primaryRect.sizeDelta.y;
    primaryRect.sizeDelta = new Vector2(primaryHeight * aspectRatio, primaryHeight);
    screen.texture = webcamTexture;

    // The secondary screen is optional; it is only resized here, it does not receive the feed yet
    if (screen2 != null)
    {
        RectTransform secondaryRect = screen2.rectTransform;
        float secondaryHeight = secondaryRect.sizeDelta.y;
        secondaryRect.sizeDelta = new Vector2(secondaryHeight * aspectRatio, secondaryHeight);
    }

    // Without a selected model only the camera preview is shown
    if (modelList.GetCurrentModel() == null)
    {
        yield break;
    }

    // TODO (original author): copying frames through a Texture2D was flagged as a stop-gap;
    // no replacement was specified
    inputTexture = new Texture2D(width, height, TextureFormat.RGBA32, false);
    pixelData = new Color32[width * height];

    // The graph assets only have to be prepared once for all instances
    if (!resourceManagerIsInitialized)
    {
        resourceManager = new StreamingAssetsResourceManager();
        string[] requiredAssets =
        {
            "pose_detection.bytes",
            "pose_landmark_full.bytes",
            "face_landmark.bytes",
            "hand_landmark_full.bytes",
            "face_detection_short_range.bytes",
            "hand_recrop.bytes",
            "handedness.txt",
        };
        foreach (string assetName in requiredAssets)
        {
            yield return resourceManager.PrepareAssetAsync(assetName);
        }
        resourceManagerIsInitialized = true;
    }

    stopwatch = new Stopwatch();

    // Setting up the graph and its three landmark output streams
    graph = new CalculatorGraph(configAsset.text);
    posestream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "pose_landmarks", "pose_landmarks_presence");
    leftstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "left_hand_landmarks", "left_hand_landmarks_presence");
    rightstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(graph, "right_hand_landmarks", "right_hand_landmarks_presence");

    // Polling is started on every stream before the graph itself is started
    posestream.StartPolling().AssertOk();
    leftstream.StartPolling().AssertOk();
    rightstream.StartPolling().AssertOk();
    graph.StartRun().AssertOk();
    stopwatch.Start();

    keypointManager = new KeypointManager(modelInfoFile);

    // Create the inference worker for the selected model and start both processing loops
    worker = modelList.GetCurrentModel().CreateWorker();
    StartCoroutine(SignRecognitionCoroutine());
    StartCoroutine(MediapipeCoroutine());
}
/// <summary>
/// Selects the model that the model list treats as current. Start reads the current model when it
/// creates the worker, so this is meant to be called at the start of a course/minigame, before
/// Start runs.
/// </summary>
/// <param name="index">The index of the model to be used</param>
public void SetModel(ModelIndex index) => modelList.SetCurrentModel(index);
/// <summary>
/// Coroutine which executes the MediaPipe pipeline: every iteration sends the current webcam frame
/// to the graph, then polls the pose and hand landmark streams and hands the results to the
/// keypoint manager.
/// </summary>
/// <returns>IEnumerator that never completes; it runs until the coroutine is stopped</returns>
private IEnumerator MediapipeCoroutine()
{
    while (true)
    {
        // The camera can be replaced at runtime (see SwapCam). Skip frames until the active
        // camera passes the same readiness check that Start uses (width above 16).
        if (webcamTexture == null || webcamTexture.width <= 16)
        {
            yield return null;
            continue;
        }

        // A different camera may have a different resolution; the buffers must match it exactly,
        // otherwise SetPixels32 and the ImageFrame dimensions below would be inconsistent.
        if (inputTexture == null || webcamTexture.width != width || webcamTexture.height != height)
        {
            width = webcamTexture.width;
            height = webcamTexture.height;
            if (inputTexture != null)
            {
                Destroy(inputTexture);
            }
            inputTexture = new Texture2D(width, height, TextureFormat.RGBA32, false);
            pixelData = new Color32[width * height];
        }

        // Copy the webcam frame into the input texture and wrap its raw bytes (4 bytes per pixel)
        inputTexture.SetPixels32(webcamTexture.GetPixels32(pixelData));
        var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, width, height, width * 4, inputTexture.GetRawTextureData<byte>());

        // Timestamp in microseconds. Elapsed.Ticks is always expressed in 100 ns TimeSpan ticks,
        // whereas Stopwatch.ElapsedTicks depends on Stopwatch.Frequency and may use another unit.
        long currentTimestamp = stopwatch.Elapsed.Ticks / (System.TimeSpan.TicksPerMillisecond / 1000);
        graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();

        yield return new WaitForEndOfFrame();

        NormalizedLandmarkList poseLandmarks = null;
        NormalizedLandmarkList leftHandLandmarks = null;
        NormalizedLandmarkList rightHandLandmarks = null;

        // Each predicate always returns true, so a stream without a result never stalls the loop;
        // in that case the corresponding landmark list simply stays null.
        yield return new WaitUntil(() => { posestream.TryGetNext(out poseLandmarks, false); return true; });
        yield return new WaitUntil(() => { leftstream.TryGetNext(out leftHandLandmarks, false); return true; });
        yield return new WaitUntil(() => { rightstream.TryGetNext(out rightHandLandmarks, false); return true; });

        keypointManager.AddLandmarks(poseLandmarks, leftHandLandmarks, rightHandLandmarks);
    }
}
/// <summary>
/// Coroutine which calls the sign predictor model: whenever the keypoint manager has keypoints
/// available, they are copied into a tensor, the model is executed spread over several frames,
/// the output is turned into softmax probabilities (stored in learnableProbabilities) and every
/// listener is notified.
/// </summary>
/// <returns>IEnumerator that never completes; it runs until the coroutine is stopped</returns>
private IEnumerator SignRecognitionCoroutine()
{
    // Number of scheduled model steps executed before control is handed back to Unity for a frame
    const int stepsPerFrame = 190;

    while (true)
    {
        List<List<float>> input = keypointManager.GetKeypoints();
        if (input == null || input.Count == 0)
        {
            // Nothing to predict on yet: wait until next frame
            yield return null;
            continue;
        }

        // Tracks whether this iteration has handed control back to Unity at least once
        bool yieldedThisIteration = false;

        int frameCount = input.Count;
        int keypointsPerFrame = input[0].Count;

        // Create a tensor with the input and fill it
        inputTensor = new Tensor(frameCount, keypointsPerFrame);
        for (int i = 0; i < frameCount; i++)
        {
            for (int j = 0; j < keypointsPerFrame; j++)
            {
                inputTensor[i, j] = input[i][j];
            }
        }

        // Execute the model step by step so that a single prediction does not block a frame
        enumerator = worker.StartManualSchedule(inputTensor);
        int step = 0;
        while (enumerator.MoveNext())
        {
            if (++step % stepsPerFrame == 0)
            {
                yieldedThisIteration = true;
                yield return null;
            }
        }

        var output = worker.PeekOutput();
        inputTensor.Dispose();
        // Clear the field so that OnDestroy does not dispose the same tensor a second time
        inputTensor = null;

        // Get the output as an array
        float[] outputArray = output.ToReadOnlyArray();

        // Softmax of the output; the maximum is subtracted first to keep Exp in a safe range
        float max = outputArray.Max();
        float[] exponentials = outputArray.Select(x => Mathf.Exp(x - max)).ToArray();
        float sum = exponentials.Sum();

        // Publish the probabilities, currently used by Courses. Labels are the letters
        // 'A', 'B', ... in the order of the model outputs.
        var probabilities = new Dictionary<string, float>(exponentials.Length);
        for (int i = 0; i < exponentials.Length; i++)
        {
            probabilities.Add(((char)('A' + i)).ToString(), exponentials[i] / sum);
        }
        learnableProbabilities = probabilities;

        // Iterate over a snapshot: listeners may be added or removed while we are suspended
        // between two listeners, which would invalidate an enumerator over the list itself.
        foreach (Listener listener in listeners.ToArray())
        {
            yieldedThisIteration = true;
            yield return listener.ProcessIncomingCall();
        }

        // With a fast model and no listeners nothing above yields; without this the loop
        // would spin forever inside a single frame and freeze the application.
        if (!yieldedThisIteration)
        {
            yield return null;
        }
    }
}
/// <summary>
/// Proper destruction of the MediaPipe graph and of every other resource owned by this component
/// (webcam, input texture, input tensor and inference worker).
/// </summary>
private void OnDestroy()
{
    if (webcamTexture != null)
    {
        webcamTexture.Stop();
    }

    try
    {
        if (graph != null)
        {
            try
            {
                graph.CloseInputStream("input_video").AssertOk();
                graph.WaitUntilDone().AssertOk();
            }
            finally
            {
                graph.Dispose();
                graph = null;
            }
        }
    }
    finally
    {
        // This cleanup must run even when shutting down the graph throws, otherwise the
        // tensor, the worker and the texture would never be released.
        inputTensor?.Dispose();
        inputTensor = null;
        worker?.Dispose();
        worker = null;

        // Textures created with "new" are not released automatically
        if (inputTexture != null)
        {
            Destroy(inputTexture);
            inputTexture = null;
        }
    }
}
/// <summary>
/// So long as there are cameras to use, swaps the camera in use for the next one in the device
/// list (wrapping around). The new feed is shown on the screen(s) that displayed the previous
/// feed; if no screen displayed it, the primary screen is used.
/// </summary>
public void SwapCam()
{
    WebCamDevice[] devices = WebCamTexture.devices;
    if (devices.Length == 0)
    {
        return;
    }

    // Remember where the old feed was displayed before the texture is replaced. Always falling
    // back to the primary screen would leave the secondary screen showing a stopped texture
    // after SwapScreen, and SwapScreen could then no longer toggle.
    WebCamTexture previous = webcamTexture;
    bool shownOnPrimary = previous != null && screen.texture == previous;
    bool shownOnSecondary = previous != null && screen2 != null && screen2.texture == previous;

    // Stop the old camera. There may be no camera at all yet, in which case there is nothing to stop.
    if (previous != null && previous.isPlaying)
    {
        previous.Stop();
    }

    // Find the new camera
    camdex = (camdex + 1) % devices.Length;

    // Start the new camera
    webcamTexture = new WebCamTexture(devices[camdex].name);
    if (shownOnSecondary)
    {
        screen2.texture = webcamTexture;
    }
    if (shownOnPrimary || !shownOnSecondary)
    {
        screen.texture = webcamTexture;
    }
    webcamTexture.Play();
}
/// <summary>
/// Swaps the display screens: moves the webcam feed from the screen that currently shows a
/// texture to the other one. Does nothing when there is no secondary screen, or when both or
/// neither of the screens currently show a texture.
/// </summary>
public void SwapScreen()
{
    // The secondary screen is optional, so there may be nothing to swap with
    if (screen2 == null)
    {
        return;
    }

    bool primaryHasTexture = screen.texture != null;
    bool secondaryHasTexture = screen2.texture != null;

    // Only an unambiguous state (exactly one screen showing a texture) can be swapped
    if (primaryHasTexture == secondaryHasTexture)
    {
        return;
    }

    if (primaryHasTexture)
    {
        screen2.texture = webcamTexture;
        screen.texture = null;
    }
    else
    {
        screen.texture = webcamTexture;
        screen2.texture = null;
    }
}
}
}