Files
unity-application/Packages/com.unity.barracuda/Editor/ONNXModelImporterEditor.cs
2023-03-18 19:53:17 +00:00

462 lines
19 KiB
C#

using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using UnityEditor;
#if UNITY_2020_2_OR_NEWER
using UnityEditor.AssetImporters;
using UnityEditor.Experimental.AssetImporters;
#else
using UnityEditor.Experimental.AssetImporters;
#endif
using UnityEngine;
using System;
using System.IO;
using System.Reflection;
using Unity.Barracuda.ONNX;
using ImportMode=Unity.Barracuda.ONNX.ONNXModelConverter.ImportMode;
using DataTypeMode=Unity.Barracuda.ONNX.ONNXModelConverter.DataTypeMode;
namespace Unity.Barracuda.Editor
{
/// <summary>
/// Asset Importer Editor of ONNX models
/// </summary>
[CustomEditor(typeof(ONNXModelImporter))]
[CanEditMultipleObjects]
public class ONNXModelImporterEditor : ScriptedImporterEditor
{
    // Reflection handle to the non-public SerializedObject.inspectorMode property, used to detect the
    // Inspector's Debug mode. Stays null if the running Unity version does not expose that property.
    static PropertyInfo s_InspectorModeInfo;

    static ONNXModelImporterEditor()
    {
        s_InspectorModeInfo = typeof(SerializedObject).GetProperty("inspectorMode", BindingFlags.NonPublic | BindingFlags.Instance);
    }

    /// <summary>
    /// Scripted importer editor UI callback
    /// </summary>
    /// <remarks>
    /// Draws every visible serialized importer field (except the script reference), plus a
    /// "Skip Metadata Import" toggle backed by the <c>ImportMode.SkipMetadataImport</c> flag.
    /// In Debug inspector mode it also exposes raw editors for the import mode and for the
    /// weights/activation data types; otherwise it shows informational help boxes.
    /// </remarks>
    public override void OnInspectorGUI()
    {
        var onnxModelImporter = target as ONNXModelImporter;
        if (onnxModelImporter == null)
            return;

        InspectorMode inspectorMode = InspectorMode.Normal;
        if (s_InspectorModeInfo != null)
            inspectorMode = (InspectorMode)s_InspectorModeInfo.GetValue(assetSerializedObject);

        serializedObject.Update();

        bool debugView = inspectorMode != InspectorMode.Normal;
        SerializedProperty iterator = serializedObject.GetIterator();
        for (bool enterChildren = true; iterator.NextVisible(enterChildren); enterChildren = false)
        {
            // The script reference is not a user-editable import setting.
            if (iterator.propertyPath != "m_Script")
                EditorGUILayout.PropertyField(iterator, true);
        }

        // Additional options exposed from ImportMode
        SerializedProperty importModeProperty = serializedObject.FindProperty(nameof(onnxModelImporter.importMode));
        bool skipMetadataImport = ((ImportMode)importModeProperty.intValue).HasFlag(ImportMode.SkipMetadataImport);
        if (EditorGUILayout.Toggle("Skip Metadata Import", skipMetadataImport) != skipMetadataImport)
        {
            // Flip only the SkipMetadataImport bit; the other ImportMode flags are kept.
            importModeProperty.intValue ^= (int)ImportMode.SkipMetadataImport;
        }

        if (debugView)
        {
            // Write back only when the user actually edits a field. Assigning intValue on every GUI pass
            // would, with several importers selected, overwrite all targets with the first one's value.
            EditorGUI.showMixedValue = importModeProperty.hasMultipleDifferentValues;
            EditorGUI.BeginChangeCheck();
            var newImportMode = (ImportMode)EditorGUILayout.EnumFlagsField("Import Mode", (ImportMode)importModeProperty.intValue);
            if (EditorGUI.EndChangeCheck())
                importModeProperty.intValue = (int)newImportMode;
            EditorGUI.showMixedValue = false;

            DataTypeModePopup("Weights type", serializedObject.FindProperty(nameof(onnxModelImporter.weightsTypeMode)));
            DataTypeModePopup("Activation type", serializedObject.FindProperty(nameof(onnxModelImporter.activationTypeMode)));
        }
        else
        {
            if (onnxModelImporter.optimizeModel)
                EditorGUILayout.HelpBox("Model optimizations are on\nRemove and re-import model if you observe incorrect behavior", MessageType.Info);
            if (onnxModelImporter.importMode == ImportMode.Legacy)
                EditorGUILayout.HelpBox("Legacy importer is in use", MessageType.Warning);
        }

        serializedObject.ApplyModifiedProperties();
        ApplyRevertGUI();
    }

    // Draws an enum popup for a DataTypeMode-valued property. The value is written back only when the
    // user changes it, and the popup shows a mixed-value indicator when the selected targets disagree.
    private static void DataTypeModePopup(string label, SerializedProperty property)
    {
        EditorGUI.showMixedValue = property.hasMultipleDifferentValues;
        EditorGUI.BeginChangeCheck();
        var value = (DataTypeMode)EditorGUILayout.EnumPopup(label, (DataTypeMode)property.intValue);
        if (EditorGUI.EndChangeCheck())
            property.intValue = (int)value;
        EditorGUI.showMixedValue = false;
    }
}
/// <summary>
/// Asset Importer Editor of NNModel (the serialized file generated by ONNXModelImporter)
/// </summary>
[CustomEditor(typeof(NNModel))]
public class NNModelEditor : UnityEditor.Editor
{
    // Use a static store for the foldouts, so it applies to all inspectors
    static Dictionary<string, bool> s_UIHelperFoldouts = new Dictionary<string, bool>();

    // Description shown for an output whose shape could not be inferred.
    private const string k_UnknownShape = "shape: (n:*, h:*, w:*, c:*)";

    private Model m_Model;
    private List<string> m_Inputs = new List<string>();
    private List<string> m_InputsDesc = new List<string>();
    private List<string> m_Outputs = new List<string>();
    private List<string> m_OutputsDesc = new List<string>();
    private List<string> m_Memories = new List<string>();
    private List<string> m_MemoriesDesc = new List<string>();
    private List<string> m_Layers = new List<string>();
    private List<string> m_LayersDesc = new List<string>();
    private List<string> m_Constants = new List<string>();
    private List<string> m_ConstantsDesc = new List<string>();

    Dictionary<string, string> m_Metadata = new Dictionary<string, string>();
    Vector2 m_MetadataScrollPosition = Vector2.zero;

    // Importer warnings bucketed by severity, keyed by layer name
    // (a later message for the same layer replaces an earlier one).
    private Dictionary<string, string> m_WarningsNeutral = new Dictionary<string, string>();
    private Dictionary<string, string> m_WarningsInfo = new Dictionary<string, string>();
    private Dictionary<string, string> m_WarningsWarning = new Dictionary<string, string>();
    private Dictionary<string, string> m_WarningsError = new Dictionary<string, string>();
    private Vector2 m_WarningsNeutralScrollPosition = Vector2.zero;
    private Vector2 m_WarningsInfoScrollPosition = Vector2.zero;
    private Vector2 m_WarningsWarningScrollPosition = Vector2.zero;
    private Vector2 m_WarningsErrorScrollPosition = Vector2.zero;

    private long m_NumEmbeddedWeights;
    private long m_NumConstantWeights;
    private long m_TotalWeightsSizeInBytes;

    private Vector2 m_InputsScrollPosition = Vector2.zero;
    private Vector2 m_OutputsScrollPosition = Vector2.zero;
    private Vector2 m_MemoriesScrollPosition = Vector2.zero;
    private Vector2 m_LayerScrollPosition = Vector2.zero;
    private Vector2 m_ConstantScrollPosition = Vector2.zero;
    private const float k_Space = 5f;

    private Texture2D m_IconTexture;

    // Finds the first asset matching ONNXModelImporter.iconName, loads it as a Texture2D and caches it.
    // Returns null when no matching asset is found.
    private Texture2D LoadIconTexture()
    {
        if (m_IconTexture != null)
            return m_IconTexture;

        string[] allCandidates = AssetDatabase.FindAssets(ONNXModelImporter.iconName);
        if (allCandidates.Length > 0)
            m_IconTexture = AssetDatabase.LoadAssetAtPath(AssetDatabase.GUIDToAssetPath(allCandidates[0]), typeof(Texture2D)) as Texture2D;

        return m_IconTexture;
    }

    /// <summary>
    /// Editor static preview rendering callback
    /// </summary>
    /// <param name="assetPath">Asset path</param>
    /// <param name="subAssets">Child assets</param>
    /// <param name="width">width</param>
    /// <param name="height">height</param>
    /// <returns>A new texture into which the importer icon has been copied, or null if the icon asset was not found.</returns>
    public override Texture2D RenderStaticPreview(string assetPath, UnityEngine.Object[] subAssets, int width, int height)
    {
        Texture2D icon = LoadIconTexture();
        if (icon == null)
            return null;
        Texture2D tex = new Texture2D(width, height);
        EditorUtility.CopySerialized(icon, tex);
        return tex;
    }

    // Appends "name:value" (values below 1 are shown as "*"), followed by ", " unless this is the last dimension.
    private void AddDimension(StringBuilder stringBuilder, string name, int value, bool lastDim=false)
    {
        string strValue = (value >= 1) ? value.ToString() : "*";
        stringBuilder.AppendFormat("{0}:{1}", name, strValue);
        if (!lastDim)
            stringBuilder.Append(", ");
    }

    // Formats an 8D (s,r,n,t,d,h,w,c) or 4D (n,h,w,c) shape for display. For 8D shapes the s/r/t/d axes
    // are only shown when at least one of them is greater than 1.
    private string GetUIStringFromShape(int[] shape)
    {
        StringBuilder stringBuilder = new StringBuilder("shape: (", 50);
        if (shape.Length == 8)
        {
            bool is8D = (shape[0] > 1 || shape[1] > 1 || shape[3] > 1 || shape[4] > 1);
            if (is8D) AddDimension(stringBuilder, "s", shape[0]);
            if (is8D) AddDimension(stringBuilder, "r", shape[1]);
            AddDimension(stringBuilder, "n", shape[2]);
            if (is8D) AddDimension(stringBuilder, "t", shape[3]);
            if (is8D) AddDimension(stringBuilder, "d", shape[4]);
            AddDimension(stringBuilder, "h", shape[5]);
            AddDimension(stringBuilder, "w", shape[6]);
            AddDimension(stringBuilder, "c", shape[7], true);
        }
        else
        {
            UnityEngine.Debug.Assert(shape.Length == 4);
            AddDimension(stringBuilder, "n", shape[0]);
            AddDimension(stringBuilder, "h", shape[1]);
            AddDimension(stringBuilder, "w", shape[2]);
            AddDimension(stringBuilder, "c", shape[3], true);
        }
        stringBuilder.Append(")");
        return stringBuilder.ToString();
    }

    // Deserializes the target NNModel and rebuilds every list and counter the inspector displays.
    // Also re-invoked from OnInspectorGUI when the model changes after a re-import, so all derived
    // state has to be rebuilt from scratch here.
    void OnEnable()
    {
        var nnModel = target as NNModel;
        if (nnModel == null)
            return;
        if (nnModel.modelData == null)
            return;

        m_Model = nnModel.GetDeserializedModel();
        if (m_Model == null)
            return;

        m_Inputs = m_Model.inputs.Select(input => input.name).ToList();
        m_InputsDesc = m_Model.inputs.Select(input => GetUIStringFromShape(input.shape)).ToList();
        m_Outputs = m_Model.outputs.ToList();

        // Output shapes are only inferred when every input shape is known well enough.
        bool allKnownInputShapes = true;
        var inputShapes = new Dictionary<string, TensorShape>();
        foreach (var input in m_Model.inputs)
        {
            if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(input))
            {
                allKnownInputShapes = false;
                break;
            }
            inputShapes.Add(input.name, new TensorShape(input.shape));
        }

        if (allKnownInputShapes)
        {
            m_OutputsDesc = m_Model.outputs.Select(output =>
            {
                string desc = k_UnknownShape;
                try
                {
                    if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, output, out TensorShape shape))
                        desc = GetUIStringFromShape(shape.ToArray());
                }
                catch (Exception e)
                {
                    // Best effort: report the failure and fall back to the unknown-shape description.
                    Debug.LogError($"Unexpected error while evaluating model output {output}. {e}");
                }
                return desc;
            }).ToList();
        }
        else
        {
            m_OutputsDesc = m_Model.outputs.Select(_ => k_UnknownShape).ToList();
        }

        m_Memories = m_Model.memories.Select(memory => memory.input).ToList();
        m_MemoriesDesc = m_Model.memories.Select(memory => $"shape:{memory.shape.ToString()} output:{memory.output}").ToList();

        // Materialize once: both sequences are enumerated several times below.
        var layers = m_Model.layers.Where(l => l.type != Layer.Type.Load).ToList();
        var constants = m_Model.layers.Where(l => l.type == Layer.Type.Load).ToList();
        m_Layers = layers.Select(l => l.type.ToString()).ToList();
        m_LayersDesc = layers.Select(l => l.ToString()).ToList();
        m_Constants = constants.Select(l => l.type.ToString()).ToList();
        m_ConstantsDesc = constants.Select(l => l.ToString()).ToList();

        m_NumEmbeddedWeights = layers.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));
        m_NumConstantWeights = constants.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));

        // Weights are not loaded for UI, so recompute the size from dataset metadata.
        // Widen to long before multiplying so the product cannot overflow if both operands are int.
        m_TotalWeightsSizeInBytes = 0;
        foreach (var layer in m_Model.layers)
            foreach (var dataset in layer.datasets)
                m_TotalWeightsSizeInBytes += (long)dataset.length * dataset.itemSizeInBytes;

        m_Metadata = new Dictionary<string, string>(m_Model.Metadata);

        // Start from empty buckets so warnings of a previously displayed model do not linger after a reload.
        m_WarningsNeutral.Clear();
        m_WarningsInfo.Clear();
        m_WarningsWarning.Clear();
        m_WarningsError.Clear();

        for (int i = 0; i < m_Model.Warnings.Count; i++)
        {
            string layerName = m_Model.Warnings[i].LayerName;
            string message = m_Model.Warnings[i].Message;
            MessageType messageType = MessageType.Warning;

            // A message may encode its severity as "MessageType", one separator character and a single
            // digit at index 12, with the actual text starting at index 13. Otherwise it is a plain warning.
            if (message.StartsWith("MessageType", StringComparison.Ordinal)
                && message.Length > 12
                && message[12] >= '0' && message[12] <= '9')
            {
                messageType = (MessageType)(message[12] - '0');
                message = message.Substring(13);
            }

            switch (messageType)
            {
                case MessageType.None:
                    m_WarningsNeutral[layerName] = message;
                    break;
                case MessageType.Info:
                    m_WarningsInfo[layerName] = message;
                    break;
                case MessageType.Error:
                    m_WarningsError[layerName] = message;
                    break;
                default:
                    // MessageType.Warning, or an unrecognized severity digit: show it rather than drop it.
                    m_WarningsWarning[layerName] = message;
                    break;
            }
        }
    }

    // Draws a button that writes the serialized model bytes to "<temporaryCachePath>/<model name>.nn"
    // and passes that path to Process.Start.
    private void OpenNNModelAsTempFileButton(NNModel nnModel)
    {
        if (nnModel == null)
            return;
        if (nnModel.modelData == null)
            return;

        if (GUILayout.Button("Open imported NN model as temp file"))
        {
            string tempPath = Application.temporaryCachePath;
            string filePath = Path.Combine(tempPath, nnModel.name);
            string filePathWithExtension = Path.ChangeExtension(filePath, "nn");
            File.WriteAllBytes(filePathWithExtension, nnModel.modelData.Value);
            System.Diagnostics.Process.Start(filePathWithExtension);
        }
    }

    /// <summary>
    /// Editor UI rendering callback
    /// </summary>
    public override void OnInspectorGUI()
    {
        if (m_Model == null)
            return;

        // HACK: When inspector settings are applied and the file is re-imported there doesn't seem to be a clean way to
        // get a notification from Unity, so we detect this change
        var nnModel = target as NNModel;
        if (nnModel && m_Model != nnModel.GetDeserializedModel())
        {
            OnEnable(); // Model data changed underneath while inspector was active, so reload
            // If the reloaded data could not be deserialized, m_Model is now null and there is nothing to draw.
            if (m_Model == null)
                return;
        }

        GUI.enabled = true;

        OpenNNModelAsTempFileButton(nnModel);
        GUILayout.Label($"Source: {m_Model.IrSource}");
        GUILayout.Label($"Version: {m_Model.IrVersion}");
        GUILayout.Label($"Producer Name: {m_Model.ProducerName}");

        if (m_Metadata.Any())
        {
            ListUIHelper($"Metadata {m_Metadata.Count}",
                m_Metadata.Keys.ToList(), m_Metadata.Values.ToList(), ref m_MetadataScrollPosition);
        }

        if (m_WarningsError.Any())
        {
            ListUIHelper($"Errors {m_WarningsError.Count.ToString()}", m_WarningsError.Keys.ToList(), m_WarningsError.Values.ToList(), ref m_WarningsErrorScrollPosition);
            EditorGUILayout.HelpBox("Model contains errors. Behavior might be incorrect", MessageType.Error, true);
        }
        if (m_WarningsWarning.Any())
        {
            ListUIHelper($"Warnings {m_WarningsWarning.Count.ToString()}", m_WarningsWarning.Keys.ToList(), m_WarningsWarning.Values.ToList(), ref m_WarningsWarningScrollPosition);
            EditorGUILayout.HelpBox("Model contains warnings. Behavior might be incorrect", MessageType.Warning, true);
        }
        if (m_WarningsInfo.Any())
        {
            ListUIHelper($"Information: ", m_WarningsInfo.Keys.ToList(), m_WarningsInfo.Values.ToList(), ref m_WarningsInfoScrollPosition);
            EditorGUILayout.HelpBox("Model contains import information.", MessageType.Info, true);
        }
        if (m_WarningsNeutral.Any())
        {
            ListUIHelper($"Comments: ", m_WarningsNeutral.Keys.ToList(), m_WarningsNeutral.Values.ToList(), ref m_WarningsNeutralScrollPosition);
        }

        var constantWeightInfo = m_Constants.Count > 0 ? $" using {m_NumConstantWeights:n0} weights" : "";

        ListUIHelper($"Inputs ({m_Inputs.Count})", m_Inputs, m_InputsDesc, ref m_InputsScrollPosition);
        ListUIHelper($"Outputs ({m_Outputs.Count})", m_Outputs, m_OutputsDesc, ref m_OutputsScrollPosition);
        ListUIHelper($"Memories ({m_Memories.Count})", m_Memories, m_MemoriesDesc, ref m_MemoriesScrollPosition);
        ListUIHelper($"Layers ({m_Layers.Count} using {m_NumEmbeddedWeights:n0} embedded weights)", m_Layers, m_LayersDesc, ref m_LayerScrollPosition, m_Constants.Count == 0 ? 1.5f: 1f);
        ListUIHelper($"Constants ({m_Constants.Count}{constantWeightInfo})", m_Constants, m_ConstantsDesc, ref m_ConstantScrollPosition);

        GUILayout.Label($"Total weight size: {m_TotalWeightsSizeInBytes:n0} bytes");
    }

    // Draws a foldout section titled sectionTitle, containing a scrollable list of "name description" rows.
    // Foldout state is remembered per title in s_UIHelperFoldouts. Right-clicking a row offers to copy that
    // row or the whole section to the clipboard. Nothing is drawn if descriptions has fewer entries than names.
    private static void ListUIHelper(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions, ref Vector2 scrollPosition, float maxHeightMultiplier = 1f)
    {
        int n = names.Count;
        UnityEngine.Debug.Assert(descriptions.Count == n);
        if (descriptions.Count < n)
            return;

        GUILayout.Space(k_Space);
        if (!s_UIHelperFoldouts.TryGetValue(sectionTitle, out bool foldout))
            foldout = true;
        foldout = EditorGUILayout.Foldout(foldout, sectionTitle, true, EditorStyles.foldoutHeader);
        s_UIHelperFoldouts[sectionTitle] = foldout;

        if (!foldout || n == 0)
            return;

        float height = Mathf.Min(n * 20f + 2f, 150f * maxHeightMultiplier);
        scrollPosition = GUILayout.BeginScrollView(scrollPosition, GUI.skin.box, GUILayout.MinHeight(height));

        Event e = Event.current;
        const float lineHeight = 16.0f;
        for (int i = 0; i < n; ++i)
        {
            Rect r = EditorGUILayout.GetControlRect(false, lineHeight);
            string name = names[i];
            string description = descriptions[i];

            // Context menu, "Copy"
            if (e.type == EventType.ContextClick && r.Contains(e.mousePosition))
            {
                e.Use();
                // Snapshot the section text now; it is only needed when a context menu is opened.
                string sectionText = BuildSectionText(sectionTitle, names, descriptions);
                var menu = new GenericMenu();
                // name/description are per-iteration locals, so each closure captures this row's values.
                menu.AddItem(new GUIContent($"Copy current line"), false, delegate
                {
                    EditorGUIUtility.systemCopyBuffer = $"{name} {description}";
                });
                menu.AddItem(new GUIContent($"Copy section"), false, delegate
                {
                    EditorGUIUtility.systemCopyBuffer = sectionText;
                });
                menu.ShowAsContext();
            }

            // Color even line for readability
            if (e.type == EventType.Repaint)
            {
                GUIStyle st = "CN EntryBackEven";
                if ((i & 1) == 0)
                    st.Draw(r, false, false, false, false);
            }

            // Name label, left-aligned and sized to its content.
            Rect locRect = r;
            locRect.xMax = locRect.xMin;
            GUIContent gc = new GUIContent(name);
            Vector2 size = EditorStyles.miniBoldLabel.CalcSize(gc);
            locRect.xMax += size.x;
            GUI.Label(locRect, gc, EditorStyles.miniBoldLabel);
            locRect.xMax += 2;

            // Description fills the remainder of the row.
            Rect msgRect = r;
            msgRect.xMin = locRect.xMax;
            GUI.Label(msgRect, new GUIContent(description), EditorStyles.miniLabel);
        }
        GUILayout.EndScrollView();
    }

    // Builds the "Copy section" clipboard text: the title line followed by one "name description" line per entry.
    private static string BuildSectionText(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions)
    {
        var fullText = new StringBuilder();
        fullText.Append(sectionTitle);
        fullText.AppendLine();
        for (int i = 0; i < names.Count; ++i)
        {
            fullText.Append($"{names[i]} {descriptions[i]}");
            fullText.AppendLine();
        }
        return fullText.ToString();
    }
}
}