// File: unity-application/Assets/MediaPipeUnity/Scripts/Wesign_extractor.cs
// Snapshot: 2023-03-18 19:53:17 +00:00 (344 lines, 13 KiB, C#)

// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
// ATTENTION!: This code is for a tutorial.
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using TMPro;
using Unity.Barracuda;
using UnityEngine;
using UnityEngine.UI;
using Debug = UnityEngine.Debug;
namespace Mediapipe.Unity.Tutorial
{
public class Wesign_extractor : MonoBehaviour
{
    /// <summary>
    /// Config file used to set up the MediaPipe graph
    /// </summary>
    [SerializeField] private TextAsset _configAsset;
    /// <summary>
    /// The screen object on which the webcam video is displayed
    /// </summary>
    [SerializeField] private RawImage _screen;
    /// <summary>
    /// MediaPipe graph
    /// </summary>
    private CalculatorGraph _graph;
    /// <summary>
    /// Resource manager for graph resources (only created by the first instance; see resourceManagerIsInitialized)
    /// </summary>
    private ResourceManager _resourceManager;
    /// <summary>
    /// Webcam texture
    /// </summary>
    private WebCamTexture _webCamTexture;
    /// <summary>
    /// CPU-side texture whose raw bytes are fed to the MediaPipe graph each frame
    /// </summary>
    private Texture2D _inputTexture;
    /// <summary>
    /// Reusable buffer for the webcam pixel data
    /// </summary>
    private Color32[] _pixelData;
    /// <summary>
    /// Stopwatch used to give a timestamp to video frames
    /// </summary>
    private Stopwatch _stopwatch;
    /// <summary>
    /// The MediaPipe stream which contains the pose landmarks
    /// </summary>
    private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> posestream;
    /// <summary>
    /// The MediaPipe stream which contains the left hand landmarks
    /// </summary>
    private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> leftstream;
    /// <summary>
    /// The MediaPipe stream which contains the right hand landmarks
    /// </summary>
    private OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList> rightstream;
    /// <summary>
    /// Presence detection stream. Note: not assigned anywhere within this class.
    /// </summary>
    public OutputStream<DetectionVectorPacket, List<Detection>> _presenceStream;
    /// <summary>
    /// A KeypointManager which does the landmark normalization and keeps track of the landmarks
    /// </summary>
    private KeypointManager k;
    /// <summary>
    /// The worker on which the sign predictor model execution is scheduled
    /// </summary>
    private IWorker worker;
    /// <summary>
    /// Width of the webcam
    /// </summary>
    private int _width;
    /// <summary>
    /// Height of the webcam
    /// </summary>
    private int _height;
    /// <summary>
    /// Name of the tracked detections stream. Note: not used within this class.
    /// </summary>
    private const string _TrackedDetectionsStreamName = "tracked_detections";
    /// <summary>
    /// Tracked detections stream. Note: not assigned within this class.
    /// </summary>
    private OutputStream<DetectionVectorPacket, List<Detection>> _trackedDetectionsStream;
    /// <summary>
    /// The enumerator of the worker which executes the sign predictor model step by step
    /// </summary>
    private IEnumerator enumerator;
    /// <summary>
    /// The latest prediction of the sign predictor model: softmax probability per letter,
    /// keyed 'A', 'B', ... in output order. Null until the first prediction completes.
    /// </summary>
    public Dictionary<char, float> letterProbabilities;
    /// <summary>
    /// Whether the resource manager has already been created and the model assets prepared.
    /// Static, so it is shared by every instance of this component.
    /// </summary>
    private static bool resourceManagerIsInitialized = false;
    /// <summary>
    /// Input tensor for the sign predictor; non-null only while an inference is in progress
    /// </summary>
    private Tensor inputTensor;

    /// <summary>
    /// Google MediaPipe setup and run: starts the webcam, prepares the graph assets,
    /// starts the graph and the sign predictor, then launches the processing coroutines.
    /// </summary>
    /// <returns> IEnumerator </returns>
    /// <exception cref="System.Exception">No webcam is available, or the sign predictor model cannot be found.</exception>
    private IEnumerator Start()
    {
        Debug.Log("starting ...");
        // Webcam setup
        if (WebCamTexture.devices.Length == 0)
        {
            throw new System.Exception("Web Camera devices are not found");
        }
        // Start the first available webcam
        WebCamDevice webCamDevice = WebCamTexture.devices[0];
        _webCamTexture = new WebCamTexture(webCamDevice.name);
        _webCamTexture.Play();
        // Wait until the webcam reports a real resolution (presumably it reports a small placeholder size before the first frame)
        yield return new WaitUntil(() => _webCamTexture.width > 16);
        // Match the screen's aspect ratio to the webcam, keeping its height
        _width = _webCamTexture.width;
        _height = _webCamTexture.height;
        float webcamAspect = (float)_webCamTexture.width / (float)_webCamTexture.height;
        _screen.rectTransform.sizeDelta = new Vector2(_screen.rectTransform.sizeDelta.y * webcamAspect, (_screen.rectTransform.sizeDelta.y));
        _screen.texture = _webCamTexture;
        // TODO(review): the original note said this approach is suboptimal but was left unfinished;
        // presumably it refers to the per-frame CPU pixel copy through _inputTexture.
        _inputTexture = new Texture2D(_width, _height, TextureFormat.RGBA32, false);
        _pixelData = new Color32[_width * _height];
        if (!resourceManagerIsInitialized)
        {
            _resourceManager = new StreamingAssetsResourceManager();
            yield return _resourceManager.PrepareAssetAsync("pose_detection.bytes");
            yield return _resourceManager.PrepareAssetAsync("pose_landmark_full.bytes");
            yield return _resourceManager.PrepareAssetAsync("face_landmark.bytes");
            yield return _resourceManager.PrepareAssetAsync("hand_landmark_full.bytes");
            yield return _resourceManager.PrepareAssetAsync("face_detection_short_range.bytes");
            yield return _resourceManager.PrepareAssetAsync("hand_recrop.bytes");
            yield return _resourceManager.PrepareAssetAsync("handedness.txt");
            resourceManagerIsInitialized = true;
        }
        _stopwatch = new Stopwatch();
        // Setting up the graph
        _graph = new CalculatorGraph(_configAsset.text);
        posestream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(_graph, "pose_landmarks", "pose_landmarks_presence");
        leftstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(_graph, "left_hand_landmarks", "left_hand_landmarks_presence");
        rightstream = new OutputStream<NormalizedLandmarkListPacket, NormalizedLandmarkList>(_graph, "right_hand_landmarks", "right_hand_landmarks_presence");
        posestream.StartPolling().AssertOk();
        leftstream.StartPolling().AssertOk();
        rightstream.StartPolling().AssertOk();
        _graph.StartRun().AssertOk();
        _stopwatch.Start();
        k = new KeypointManager();
        // Load the sign predictor; Resources.Load returns null when the asset does not exist,
        // so fail with a clear message instead of an opaque error inside ModelLoader.
        const string modelPath = "Models/Fingerspelling/model_A-L";
        NNModel modelAsset = Resources.Load<NNModel>(modelPath);
        if (modelAsset == null)
        {
            throw new System.Exception($"Sign predictor model not found in Resources at '{modelPath}'");
        }
        var model = ModelLoader.Load(modelAsset);
        worker = model.CreateWorker();
        StartCoroutine(SignRecognitionCoroutine());
        StartCoroutine(MediapipeCoroutine());
    }

    /// <summary>
    /// Coroutine which executes the MediaPipe pipeline: every iteration it pushes the current
    /// webcam frame into the graph, polls the pose/hand landmark streams once each, and hands
    /// the results to the KeypointManager.
    /// </summary>
    /// <returns> IEnumerator </returns>
    private IEnumerator MediapipeCoroutine()
    {
        while (true)
        {
            _inputTexture.SetPixels32(_webCamTexture.GetPixels32(_pixelData));
            var imageFrame = new ImageFrame(ImageFormat.Types.Format.Srgba, _width, _height, _width * 4, _inputTexture.GetRawTextureData<byte>());
            // Timestamp in microseconds. Elapsed.Ticks is always in 100 ns TimeSpan ticks, whereas
            // ElapsedTicks is in Stopwatch.Frequency units and is only correct here when Frequency is 10 MHz.
            var currentTimestamp = _stopwatch.Elapsed.Ticks / (System.TimeSpan.TicksPerMillisecond / 1000);
            _graph.AddPacketToInputStream("input_video", new ImageFramePacket(imageFrame, new Timestamp(currentTimestamp))).AssertOk();
            yield return new WaitForEndOfFrame();
            Mediapipe.NormalizedLandmarkList _poseLandmarks = null;
            Mediapipe.NormalizedLandmarkList _leftHandLandmarks = null;
            Mediapipe.NormalizedLandmarkList _rightHandLandmarks = null;
            // Each predicate returns true immediately, so every WaitUntil performs a single
            // non-blocking poll (allowBlock: false). NOTE(review): a stream with no new packet
            // presumably yields null here — confirm against OutputStream.TryGetNext.
            yield return new WaitUntil(() => { posestream.TryGetNext(out _poseLandmarks, false); return true; });
            yield return new WaitUntil(() => { leftstream.TryGetNext(out _leftHandLandmarks, false); return true; });
            yield return new WaitUntil(() => { rightstream.TryGetNext(out _rightHandLandmarks, false); return true; });
            k.addLandmarks(_poseLandmarks, _leftHandLandmarks, _rightHandLandmarks);
        }
    }

    /// <summary>
    /// Coroutine which calls the sign predictor model on the keypoints collected by the
    /// KeypointManager and publishes the softmaxed result in <see cref="letterProbabilities"/>.
    /// The model execution is spread over several frames.
    /// </summary>
    /// <returns> IEnumerator </returns>
    private IEnumerator SignRecognitionCoroutine()
    {
        // Number of manual-schedule steps (MoveNext calls) executed before yielding to the next frame
        const int stepsPerFrame = 190;
        while (true)
        {
            List<List<float>> input = k.getAllKeypoints();
            if (input == null || input.Count == 0)
            {
                // No landmarks yet: wait until the next frame
                yield return null;
                continue;
            }

            int frameCount = input.Count;
            int keypointsPerFrame = input[0].Count;
            // Create a (frames x keypoints) tensor and fill it with the input
            inputTensor = new Tensor(frameCount, keypointsPerFrame);
            for (int i = 0; i < frameCount; i++)
            {
                for (int j = 0; j < keypointsPerFrame; j++)
                {
                    inputTensor[i, j] = input[i][j];
                }
            }

            enumerator = worker.StartManualSchedule(inputTensor);
            int step = 0;
            while (enumerator.MoveNext())
            {
                if (++step % stepsPerFrame == 0)
                {
                    yield return null;
                }
            }
            var output = worker.PeekOutput();
            inputTensor.Dispose();
            // Clear the reference so OnDestroy does not dispose the tensor a second time
            inputTensor = null;

            // Softmax of the raw model output (shifted by the max for numerical stability)
            float[] outputArray = output.ToReadOnlyArray();
            float max = outputArray.Max();
            float[] exponentials = outputArray.Select(x => Mathf.Exp(x - max)).ToArray();
            float sum = exponentials.Sum();

            // Publish one probability per letter, 'A' + output index (currently used by Courses)
            var probabilities = new Dictionary<char, float>(exponentials.Length);
            for (int i = 0; i < exponentials.Length; i++)
            {
                probabilities.Add((char)('A' + i), exponentials[i] / sum);
            }
            letterProbabilities = probabilities;

            // Always yield after a prediction. Otherwise a model that completes in fewer than
            // stepsPerFrame steps would make this loop run forever without yielding, freezing the frame.
            yield return null;
        }
    }

    /// <summary>
    /// Proper destruction of the webcam, the Barracuda resources and the MediaPipe graph
    /// </summary>
    private void OnDestroy()
    {
        if (_webCamTexture != null)
        {
            _webCamTexture.Stop();
        }
        // Release the Barracuda resources before shutting down the graph, so that a failing
        // AssertOk below cannot leak them. inputTensor is only non-null when destruction
        // interrupted an inference; worker is null when Start() failed before loading the model.
        inputTensor?.Dispose();
        inputTensor = null;
        worker?.Dispose();
        worker = null;
        if (_graph != null)
        {
            try
            {
                _graph.CloseInputStream("input_video").AssertOk();
                _graph.WaitUntilDone().AssertOk();
            }
            finally
            {
                _graph.Dispose();
            }
        }
    }
}
}