462 lines
19 KiB
C#
462 lines
19 KiB
C#
using System.Collections.Generic;
|
|
using System.Globalization;
|
|
using System.Linq;
|
|
using System.Text;
|
|
using UnityEditor;
|
|
#if UNITY_2020_2_OR_NEWER
|
|
using UnityEditor.AssetImporters;
|
|
using UnityEditor.Experimental.AssetImporters;
|
|
#else
|
|
using UnityEditor.Experimental.AssetImporters;
|
|
#endif
|
|
using UnityEngine;
|
|
using System;
|
|
using System.IO;
|
|
using System.Reflection;
|
|
using Unity.Barracuda.ONNX;
|
|
using ImportMode=Unity.Barracuda.ONNX.ONNXModelConverter.ImportMode;
|
|
using DataTypeMode=Unity.Barracuda.ONNX.ONNXModelConverter.DataTypeMode;
|
|
|
|
namespace Unity.Barracuda.Editor
|
|
{
|
|
/// <summary>
/// Asset Importer Editor of ONNX models
/// </summary>
[CustomEditor(typeof(ONNXModelImporter))]
[CanEditMultipleObjects]
public class ONNXModelImporterEditor : ScriptedImporterEditor
{
    // Reflection handle on the non-public SerializedObject.inspectorMode property, used to tell whether the
    // Inspector is in Normal or Debug mode. Stays null when the property cannot be found.
    static readonly PropertyInfo s_InspectorModeInfo;

    static ONNXModelImporterEditor()
    {
        const BindingFlags nonPublicInstance = BindingFlags.NonPublic | BindingFlags.Instance;
        s_InspectorModeInfo = typeof(SerializedObject).GetProperty("inspectorMode", nonPublicInstance);
    }

    /// <summary>
    /// Scripted importer editor UI callback
    /// </summary>
    /// <remarks>
    /// Draws every visible serialized field of the importer except <c>m_Script</c>, then a
    /// "Skip Metadata Import" toggle that flips the <c>SkipMetadataImport</c> bit of <c>importMode</c>.
    /// When the Inspector is not in Normal mode, it also exposes the raw import-mode flags and the
    /// weights/activation data-type modes; in Normal mode it shows informational help boxes instead.
    /// </remarks>
    public override void OnInspectorGUI()
    {
        var importer = target as ONNXModelImporter;
        // Unity's overloaded == also reports destroyed objects as null, so keep this form of the check.
        if (importer == null)
            return;

        InspectorMode mode = s_InspectorModeInfo != null
            ? (InspectorMode)s_InspectorModeInfo.GetValue(assetSerializedObject)
            : InspectorMode.Normal;

        serializedObject.Update();

        bool showDebugOptions = mode != InspectorMode.Normal;

        DrawFieldsExceptScript();

        // Additional options exposed from ImportMode
        SerializedProperty importModeProp = serializedObject.FindProperty(nameof(importer.importMode));
        bool wasSkipping = ((ImportMode)importModeProp.intValue).HasFlag(ImportMode.SkipMetadataImport);
        bool wantsSkipping = EditorGUILayout.Toggle("Skip Metadata Import", wasSkipping);
        if (wantsSkipping != wasSkipping)
            importModeProp.intValue ^= (int)ImportMode.SkipMetadataImport; // flip only that bit

        if (!showDebugOptions)
        {
            DrawNormalModeHints();
        }
        else
        {
            // intValue is read again on purpose: the toggle above may just have changed it.
            importModeProp.intValue = (int)(ImportMode)EditorGUILayout.EnumFlagsField("Import Mode", (ImportMode)importModeProp.intValue);

            SerializedProperty weightsProp = serializedObject.FindProperty(nameof(importer.weightsTypeMode));
            SerializedProperty activationProp = serializedObject.FindProperty(nameof(importer.activationTypeMode));
            DrawDataTypePopup("Weights type", weightsProp);
            DrawDataTypePopup("Activation type", activationProp);
        }

        serializedObject.ApplyModifiedProperties();

        ApplyRevertGUI();

        // Draws all visible top-level serialized properties (with children), skipping the script reference.
        void DrawFieldsExceptScript()
        {
            SerializedProperty property = serializedObject.GetIterator();
            bool descend = true;
            while (property.NextVisible(descend))
            {
                descend = false;
                if (property.propertyPath == "m_Script")
                    continue;
                EditorGUILayout.PropertyField(property, true);
            }
        }

        // Help boxes shown only in the Normal Inspector.
        void DrawNormalModeHints()
        {
            if (importer.optimizeModel)
                EditorGUILayout.HelpBox("Model optimizations are on\nRemove and re-import model if you observe incorrect behavior", MessageType.Info);

            if (importer.importMode == ImportMode.Legacy)
                EditorGUILayout.HelpBox("Legacy importer is in use", MessageType.Warning);
        }

        // Edits an int-backed DataTypeMode property through an enum popup.
        void DrawDataTypePopup(string label, SerializedProperty prop) =>
            prop.intValue = (int)(DataTypeMode)EditorGUILayout.EnumPopup(label, (DataTypeMode)prop.intValue);
    }
}
|
|
|
|
/// <summary>
/// Asset Importer Editor of NNModel (the serialized file generated by ONNXModelImporter)
/// </summary>
[CustomEditor(typeof(NNModel))]
public class NNModelEditor : UnityEditor.Editor
{
    // Use a static store for the foldouts, so it applies to all inspectors
    static readonly Dictionary<string, bool> s_UIHelperFoldouts = new Dictionary<string, bool>();

    // Description shown for an output whose shape cannot be inferred
    private const string k_UnknownShapeDescription = "shape: (n:*, h:*, w:*, c:*)";

    // Prefix that some importer warning messages carry to encode their severity (see ParseWarningSeverity)
    private const string k_MessageTypePrefix = "MessageType";

    private Model m_Model;
    private List<string> m_Inputs = new List<string>();
    private List<string> m_InputsDesc = new List<string>();
    private List<string> m_Outputs = new List<string>();
    private List<string> m_OutputsDesc = new List<string>();
    private List<string> m_Memories = new List<string>();
    private List<string> m_MemoriesDesc = new List<string>();
    private List<string> m_Layers = new List<string>();
    private List<string> m_LayersDesc = new List<string>();
    private List<string> m_Constants = new List<string>();
    private List<string> m_ConstantsDesc = new List<string>();

    Dictionary<string, string> m_Metadata = new Dictionary<string, string>();
    Vector2 m_MetadataScrollPosition = Vector2.zero;

    // warnings, one table per severity, keyed by layer name
    private Dictionary<string, string> m_WarningsNeutral = new Dictionary<string, string>();
    private Dictionary<string, string> m_WarningsInfo = new Dictionary<string, string>();
    private Dictionary<string, string> m_WarningsWarning = new Dictionary<string, string>();
    private Dictionary<string, string> m_WarningsError = new Dictionary<string, string>();
    private Vector2 m_WarningsNeutralScrollPosition = Vector2.zero;
    private Vector2 m_WarningsInfoScrollPosition = Vector2.zero;
    private Vector2 m_WarningsWarningScrollPosition = Vector2.zero;
    private Vector2 m_WarningsErrorScrollPosition = Vector2.zero;

    private long m_NumEmbeddedWeights;
    private long m_NumConstantWeights;
    private long m_TotalWeightsSizeInBytes;

    private Vector2 m_InputsScrollPosition = Vector2.zero;
    private Vector2 m_OutputsScrollPosition = Vector2.zero;
    private Vector2 m_MemoriesScrollPosition = Vector2.zero;
    private Vector2 m_LayerScrollPosition = Vector2.zero;
    private Vector2 m_ConstantScrollPosition = Vector2.zero;
    private const float k_Space = 5f;

    private Texture2D m_IconTexture;

    // Returns the icon texture, found once through AssetDatabase.FindAssets(ONNXModelImporter.iconName)
    // and cached afterwards. Returns null when no matching asset exists.
    private Texture2D LoadIconTexture()
    {
        if (m_IconTexture != null)
            return m_IconTexture;

        string[] allCandidates = AssetDatabase.FindAssets(ONNXModelImporter.iconName);
        if (allCandidates.Length > 0)
            m_IconTexture = AssetDatabase.LoadAssetAtPath(AssetDatabase.GUIDToAssetPath(allCandidates[0]), typeof(Texture2D)) as Texture2D;

        return m_IconTexture;
    }

    /// <summary>
    /// Editor static preview rendering callback
    /// </summary>
    /// <param name="assetPath">Asset path</param>
    /// <param name="subAssets">Child assets</param>
    /// <param name="width">width</param>
    /// <param name="height">height</param>
    /// <returns>A copy of the icon texture, or null when no icon asset can be found.</returns>
    public override Texture2D RenderStaticPreview(string assetPath, UnityEngine.Object[] subAssets, int width, int height)
    {
        Texture2D icon = LoadIconTexture();
        if (icon == null)
            return null;
        Texture2D tex = new Texture2D(width, height);
        EditorUtility.CopySerialized(icon, tex);
        return tex;
    }

    // Appends "name:value" ("name:*" when value < 1, i.e. unknown), followed by ", " unless lastDim.
    private void AddDimension(StringBuilder stringBuilder, string name, int value, bool lastDim=false)
    {
        string strValue = (value >= 1) ? value.ToString(CultureInfo.InvariantCulture) : "*";
        stringBuilder.AppendFormat("{0}:{1}", name, strValue);
        if (!lastDim)
            stringBuilder.Append(", ");
    }

    // Formats a shape for display as "shape: (...)".
    // 8-element shapes show n, h, w, c, plus s, r, t, d when any of those is greater than 1.
    // 4-element shapes show n, h, w, c.
    // Any other rank (including null) is listed generically as d0, d1, ... instead of indexing out of range.
    private string GetUIStringFromShape(int[] shape)
    {
        StringBuilder stringBuilder = new StringBuilder("shape: (", 50);
        int rank = shape?.Length ?? 0;
        if (rank == 8)
        {
            bool is8D = (shape[0] > 1 || shape[1] > 1 || shape[3] > 1 || shape[4] > 1);
            if (is8D) AddDimension(stringBuilder, "s", shape[0]);
            if (is8D) AddDimension(stringBuilder, "r", shape[1]);
            AddDimension(stringBuilder, "n", shape[2]);
            if (is8D) AddDimension(stringBuilder, "t", shape[3]);
            if (is8D) AddDimension(stringBuilder, "d", shape[4]);
            AddDimension(stringBuilder, "h", shape[5]);
            AddDimension(stringBuilder, "w", shape[6]);
            AddDimension(stringBuilder, "c", shape[7], true);
        }
        else if (rank == 4)
        {
            AddDimension(stringBuilder, "n", shape[0]);
            AddDimension(stringBuilder, "h", shape[1]);
            AddDimension(stringBuilder, "w", shape[2]);
            AddDimension(stringBuilder, "c", shape[3], true);
        }
        else
        {
            UnityEngine.Debug.Assert(false, $"Unexpected shape rank {rank}");
            for (int d = 0; d < rank; ++d)
                AddDimension(stringBuilder, $"d{d}", shape[d], d == rank - 1);
        }
        stringBuilder.Append(")");
        return stringBuilder.ToString();
    }

    // Builds the display-ready description of the target model used by OnInspectorGUI.
    // Also re-run from OnInspectorGUI when the deserialized model changes (e.g. after a re-import).
    void OnEnable()
    {
        // Start from a clean state. The warning tables are only ever added to, so without clearing them
        // a reload would keep the previous import's warnings. m_Model must not keep pointing at an old model
        // when there is no model data any more.
        m_Model = null;
        m_WarningsNeutral.Clear();
        m_WarningsInfo.Clear();
        m_WarningsWarning.Clear();
        m_WarningsError.Clear();

        var nnModel = target as NNModel;
        if (nnModel == null || nnModel.modelData == null)
            return;

        m_Model = nnModel.GetDeserializedModel();
        if (m_Model == null)
            return;

        m_Inputs = m_Model.inputs.Select(i => i.name).ToList();
        m_InputsDesc = m_Model.inputs.Select(i => GetUIStringFromShape(i.shape)).ToList();
        m_Outputs = m_Model.outputs.ToList();
        m_OutputsDesc = BuildOutputDescriptions();

        m_Memories = m_Model.memories.Select(i => i.input).ToList();
        m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

        // Materialized once: each list is enumerated several times below.
        var layers = m_Model.layers.Where(i => i.type != Layer.Type.Load).ToList();
        var constants = m_Model.layers.Where(i => i.type == Layer.Type.Load).ToList();

        m_Layers = layers.Select(i => i.type.ToString()).ToList();
        m_LayersDesc = layers.Select(i => i.ToString()).ToList();
        m_Constants = constants.Select(i => i.type.ToString()).ToList();
        m_ConstantsDesc = constants.Select(i => i.ToString()).ToList();

        m_NumEmbeddedWeights = layers.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));
        m_NumConstantWeights = constants.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));

        // weights are not loaded for UI, recompute size
        m_TotalWeightsSizeInBytes = 0;
        foreach (var layer in m_Model.layers)
        {
            foreach (var dataset in layer.datasets)
            {
                // Widen before multiplying so large tensors cannot overflow 32-bit arithmetic.
                m_TotalWeightsSizeInBytes += (long)dataset.length * dataset.itemSizeInBytes;
            }
        }

        m_Metadata = new Dictionary<string, string>(m_Model.Metadata);

        CollectWarnings();
    }

    // Returns one shape description per model output.
    // Shapes are inferred only when every input shape is known well enough for shape inference.
    // Otherwise every output gets the unknown-shape placeholder.
    private List<string> BuildOutputDescriptions()
    {
        var inputShapes = new Dictionary<string, TensorShape>();
        foreach (var input in m_Model.inputs)
        {
            if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(input))
                return m_Model.outputs.Select(o => k_UnknownShapeDescription).ToList();
            inputShapes.Add(input.name, new TensorShape(input.shape));
        }

        return m_Model.outputs.Select(output => DescribeOutput(output, inputShapes)).ToList();
    }

    // Describes the inferred shape of one output.
    // Inference failures are logged and shown as the unknown-shape placeholder.
    private string DescribeOutput(string outputName, Dictionary<string, TensorShape> inputShapes)
    {
        try
        {
            TensorShape shape;
            if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, outputName, out shape))
                return GetUIStringFromShape(shape.ToArray());
        }
        catch (Exception e)
        {
            UnityEngine.Debug.LogError($"Unexpected error while evaluating model output {outputName}. {e}");
        }
        return k_UnknownShapeDescription;
    }

    // Sorts the model's importer warnings into one table per severity, keyed by layer name.
    private void CollectWarnings()
    {
        foreach (var importerWarning in m_Model.Warnings)
        {
            // A null key would make the dictionary store throw; show such warnings under an empty name.
            string layerName = importerWarning.LayerName ?? string.Empty;
            MessageType messageType = ParseWarningSeverity(importerWarning.Message, out string message);

            // NOTE(review): tables are keyed by layer name, so a later warning of the same severity for the
            // same layer replaces an earlier one. Kept as-is; confirm whether that is intended.
            switch (messageType)
            {
                case MessageType.None:
                    m_WarningsNeutral[layerName] = message;
                    break;
                case MessageType.Info:
                    m_WarningsInfo[layerName] = message;
                    break;
                case MessageType.Warning:
                    m_WarningsWarning[layerName] = message;
                    break;
                case MessageType.Error:
                    m_WarningsError[layerName] = message;
                    break;
            }
        }
    }

    // Decodes the optional severity prefix of an importer warning message.
    // A message of the form "MessageType" + one skipped character + a digit (a MessageType value) + text
    // yields that severity and the text after the digit. Anything else, including a short, null or
    // out-of-range-digit message, is reported as a Warning with the message left intact.
    private static MessageType ParseWarningSeverity(string rawMessage, out string message)
    {
        message = rawMessage ?? string.Empty;
        const int severityIndex = 12; // k_MessageTypePrefix.Length + one separator character
        if (!message.StartsWith(k_MessageTypePrefix, StringComparison.Ordinal) || message.Length <= severityIndex)
            return MessageType.Warning;

        int severity = message[severityIndex] - '0';
        if (severity < (int)MessageType.None || severity > (int)MessageType.Error)
            return MessageType.Warning;

        message = message.Substring(severityIndex + 1);
        return (MessageType)severity;
    }

    // Draws a button that writes the serialized model bytes to "<temporaryCachePath>/<asset name>.nn"
    // and opens that file with the OS-associated application.
    private void OpenNNModelAsTempFileButton(NNModel nnModel)
    {
        if (nnModel == null || nnModel.modelData == null)
            return;

        if (!GUILayout.Button("Open imported NN model as temp file"))
            return;

        // Append the extension rather than using Path.ChangeExtension. Asset names may contain dots
        // (e.g. "model.v2"), and ChangeExtension would replace that suffix.
        string filePathWithExtension = Path.Combine(Application.temporaryCachePath, nnModel.name + ".nn");
        try
        {
            File.WriteAllBytes(filePathWithExtension, nnModel.modelData.Value);
            System.Diagnostics.Process.Start(filePathWithExtension);
        }
        catch (Exception e)
        {
            // Report IO / launch failures without aborting the rest of this inspector's GUI pass.
            UnityEngine.Debug.LogException(e);
        }
    }

    /// <summary>
    /// Editor UI rendering callback
    /// </summary>
    public override void OnInspectorGUI()
    {
        if (m_Model == null)
            return;

        // HACK: When inspector settings are applied and the file is re-imported there doesn't seem to be a clean way to
        // get a notification from Unity, so we detect this change
        var nnModel = target as NNModel;
        if (nnModel && m_Model != nnModel.GetDeserializedModel())
        {
            OnEnable(); // Model data changed underneath while inspector was active, so reload
            // The reload may leave no model (no data, or deserialization returned null); stop before using it.
            if (m_Model == null)
                return;
        }

        GUI.enabled = true;
        OpenNNModelAsTempFileButton(nnModel);
        GUILayout.Label($"Source: {m_Model.IrSource}");
        GUILayout.Label($"Version: {m_Model.IrVersion}");
        GUILayout.Label($"Producer Name: {m_Model.ProducerName}");

        if (m_Metadata.Any())
        {
            ListUIHelper($"Metadata {m_Metadata.Count}",
                m_Metadata.Keys.ToList(), m_Metadata.Values.ToList(), ref m_MetadataScrollPosition);
        }

        if (m_WarningsError.Any())
        {
            ListUIHelper($"Errors {m_WarningsError.Count.ToString()}", m_WarningsError.Keys.ToList(), m_WarningsError.Values.ToList(), ref m_WarningsErrorScrollPosition);
            EditorGUILayout.HelpBox("Model contains errors. Behavior might be incorrect", MessageType.Error, true);
        }
        if (m_WarningsWarning.Any())
        {
            ListUIHelper($"Warnings {m_WarningsWarning.Count.ToString()}", m_WarningsWarning.Keys.ToList(), m_WarningsWarning.Values.ToList(), ref m_WarningsWarningScrollPosition);
            EditorGUILayout.HelpBox("Model contains warnings. Behavior might be incorrect", MessageType.Warning, true);
        }
        if (m_WarningsInfo.Any())
        {
            ListUIHelper($"Information: ", m_WarningsInfo.Keys.ToList(), m_WarningsInfo.Values.ToList(), ref m_WarningsInfoScrollPosition);
            EditorGUILayout.HelpBox("Model contains import information.", MessageType.Info, true);
        }
        if (m_WarningsNeutral.Any())
        {
            ListUIHelper($"Comments: ", m_WarningsNeutral.Keys.ToList(), m_WarningsNeutral.Values.ToList(), ref m_WarningsNeutralScrollPosition);
        }
        var constantWeightInfo = m_Constants.Count > 0 ? $" using {m_NumConstantWeights:n0} weights" : "";
        ListUIHelper($"Inputs ({m_Inputs.Count})", m_Inputs, m_InputsDesc, ref m_InputsScrollPosition);
        ListUIHelper($"Outputs ({m_Outputs.Count})", m_Outputs, m_OutputsDesc, ref m_OutputsScrollPosition);
        ListUIHelper($"Memories ({m_Memories.Count})", m_Memories, m_MemoriesDesc, ref m_MemoriesScrollPosition);
        ListUIHelper($"Layers ({m_Layers.Count} using {m_NumEmbeddedWeights:n0} embedded weights)", m_Layers, m_LayersDesc, ref m_LayerScrollPosition, m_Constants.Count == 0 ? 1.5f: 1f);
        ListUIHelper($"Constants ({m_Constants.Count}{constantWeightInfo})", m_Constants, m_ConstantsDesc, ref m_ConstantScrollPosition);

        GUILayout.Label($"Total weight size: {m_TotalWeightsSizeInBytes:n0} bytes");
    }

    // Draws a foldout section titled sectionTitle listing "name description" rows in a scroll view.
    // Foldout state is remembered per title in s_UIHelperFoldouts. Right-clicking a row offers copying that
    // row or the whole section to the clipboard. Nothing is drawn when descriptions has fewer entries than names.
    private static void ListUIHelper(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions, ref Vector2 scrollPosition, float maxHeightMultiplier = 1f)
    {
        int n = names.Count;
        UnityEngine.Debug.Assert(descriptions.Count == n);
        if (descriptions.Count < n)
            return;

        GUILayout.Space(k_Space);
        if (!s_UIHelperFoldouts.TryGetValue(sectionTitle, out bool foldout))
            foldout = true;

        foldout = EditorGUILayout.Foldout(foldout, sectionTitle, true, EditorStyles.foldoutHeader);
        s_UIHelperFoldouts[sectionTitle] = foldout;
        if (!foldout || n == 0)
            return;

        float height = Mathf.Min(n * 20f + 2f, 150f * maxHeightMultiplier);
        scrollPosition = GUILayout.BeginScrollView(scrollPosition, GUI.skin.box, GUILayout.MinHeight(height));
        Event e = Event.current;
        const float lineHeight = 16.0f;
        GUIStyle evenRowStyle = "CN EntryBackEven"; // looked up once per call rather than once per row

        for (int i = 0; i < n; ++i)
        {
            Rect r = EditorGUILayout.GetControlRect(false, lineHeight);

            string name = names[i];
            string description = descriptions[i];

            // Context menu, "Copy"
            if (e.type == EventType.ContextClick && r.Contains(e.mousePosition))
            {
                e.Use();
                // Texts are built only here, on an actual context click, instead of on every GUI event.
                string lineText = $"{name} {description}";
                string sectionText = BuildSectionText(sectionTitle, names, descriptions, n);
                var menu = new GenericMenu();
                menu.AddItem(new GUIContent("Copy current line"), false, () => EditorGUIUtility.systemCopyBuffer = lineText);
                menu.AddItem(new GUIContent("Copy section"), false, () => EditorGUIUtility.systemCopyBuffer = sectionText);
                menu.ShowAsContext();
            }

            // Color even line for readability
            if (e.type == EventType.Repaint && (i & 1) == 0)
                evenRowStyle.Draw(r, false, false, false, false);

            // name on the left side, sized to its content
            Rect locRect = r;
            locRect.xMax = locRect.xMin;
            GUIContent gc = new GUIContent(name);
            Vector2 size = EditorStyles.miniBoldLabel.CalcSize(gc);
            locRect.xMax += size.x;
            GUI.Label(locRect, gc, EditorStyles.miniBoldLabel);
            locRect.xMax += 2;

            // description fills the rest of the row
            Rect msgRect = r;
            msgRect.xMin = locRect.xMax;
            GUI.Label(msgRect, new GUIContent(description), EditorStyles.miniLabel);
        }

        GUILayout.EndScrollView();
    }

    // Section title on the first line, then one "name description" line per entry.
    private static string BuildSectionText(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions, int count)
    {
        var fullText = new StringBuilder();
        fullText.Append(sectionTitle);
        fullText.AppendLine();
        for (int i = 0; i < count; ++i)
        {
            fullText.Append($"{names[i]} {descriptions[i]}");
            fullText.AppendLine();
        }
        return fullText.ToString();
    }
}
|
|
|
|
}
|