using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using UnityEditor;
#if UNITY_2020_2_OR_NEWER
using UnityEditor.AssetImporters;
using UnityEditor.Experimental.AssetImporters;
#else
using UnityEditor.Experimental.AssetImporters;
#endif
using UnityEngine;
using System;
using System.IO;
using System.Reflection;
using Unity.Barracuda.ONNX;
using ImportMode=Unity.Barracuda.ONNX.ONNXModelConverter.ImportMode;
using DataTypeMode=Unity.Barracuda.ONNX.ONNXModelConverter.DataTypeMode;

namespace Unity.Barracuda.Editor
{
    /// <summary>
    /// Asset Importer Editor of ONNX models
    /// </summary>
    [CustomEditor(typeof(ONNXModelImporter))]
    [CanEditMultipleObjects]
    public class ONNXModelImporterEditor : ScriptedImporterEditor
    {
        // Reflection handle to the non-public SerializedObject.inspectorMode property;
        // stays null when the property cannot be found.
        static readonly PropertyInfo s_InspectorModeInfo;

        static ONNXModelImporterEditor()
        {
            s_InspectorModeInfo = typeof(SerializedObject).GetProperty("inspectorMode", BindingFlags.NonPublic | BindingFlags.Instance);
        }

        /// <summary>
        /// Scripted importer editor UI callback
        /// </summary>
        public override void OnInspectorGUI()
        {
            var onnxModelImporter = target as ONNXModelImporter;
            if (onnxModelImporter == null)
                return;

            // Falls back to the normal inspector when the reflected property is unavailable.
            InspectorMode inspectorMode = InspectorMode.Normal;
            if (s_InspectorModeInfo != null)
                inspectorMode = (InspectorMode)s_InspectorModeInfo.GetValue(assetSerializedObject);

            serializedObject.Update();

            bool debugView = inspectorMode != InspectorMode.Normal;

            // Draw every visible serialized property except the script reference.
            SerializedProperty iterator = serializedObject.GetIterator();
            for (bool enterChildren = true; iterator.NextVisible(enterChildren); enterChildren = false)
            {
                if (iterator.propertyPath != "m_Script")
                    EditorGUILayout.PropertyField(iterator, true);
            }

            // Additional options exposed from ImportMode
            SerializedProperty importModeProperty = serializedObject.FindProperty(nameof(onnxModelImporter.importMode));
            bool skipMetadataImport = ((ImportMode)importModeProperty.intValue).HasFlag(ImportMode.SkipMetadataImport);
            if (EditorGUILayout.Toggle("Skip Metadata Import", skipMetadataImport) != skipMetadataImport)
            {
                // The toggle changed, so flip just that flag bit.
                importModeProperty.intValue ^= (int)ImportMode.SkipMetadataImport;
            }

            if (debugView)
            {
                importModeProperty.intValue = (int)(ImportMode)EditorGUILayout.EnumFlagsField("Import Mode", (ImportMode)importModeProperty.intValue);

                SerializedProperty weightsTypeMode = serializedObject.FindProperty(nameof(onnxModelImporter.weightsTypeMode));
                SerializedProperty activationTypeMode = serializedObject.FindProperty(nameof(onnxModelImporter.activationTypeMode));
                weightsTypeMode.intValue = (int)(DataTypeMode)EditorGUILayout.EnumPopup("Weights type", (DataTypeMode)weightsTypeMode.intValue);
                activationTypeMode.intValue = (int)(DataTypeMode)EditorGUILayout.EnumPopup("Activation type", (DataTypeMode)activationTypeMode.intValue);
            }
            else
            {
                if (onnxModelImporter.optimizeModel)
                    EditorGUILayout.HelpBox("Model optimizations are on\nRemove and re-import model if you observe incorrect behavior", MessageType.Info);

                if (onnxModelImporter.importMode == ImportMode.Legacy)
                    EditorGUILayout.HelpBox("Legacy importer is in use", MessageType.Warning);
            }

            serializedObject.ApplyModifiedProperties();

            ApplyRevertGUI();
        }
    }

    /// <summary>
    /// Asset Importer Editor of NNModel (the serialized file generated by ONNXModelImporter)
    /// </summary>
    [CustomEditor(typeof(NNModel))]
    public class NNModelEditor : UnityEditor.Editor
    {
        // Use a static store for the foldouts, so it applies to all inspectors
        static Dictionary<string, bool> s_UIHelperFoldouts = new Dictionary<string, bool>();

        private Model m_Model;

        // Each section keeps two parallel lists: the entry names and their descriptions.
        private List<string> m_Inputs = new List<string>();
        private List<string> m_InputsDesc = new List<string>();
        private List<string> m_Outputs = new List<string>();
        private List<string> m_OutputsDesc = new List<string>();
        private List<string> m_Memories = new List<string>();
        private List<string> m_MemoriesDesc = new List<string>();
        private List<string> m_Layers = new List<string>();
        private List<string> m_LayersDesc = new List<string>();
        private List<string> m_Constants = new List<string>();
        private List<string> m_ConstantsDesc = new List<string>();

        Dictionary<string, string> m_Metadata = new Dictionary<string, string>();
        Vector2 m_MetadataScrollPosition = Vector2.zero;

        // warnings, keyed by layer name and bucketed by message type
        private Dictionary<string, string> m_WarningsNeutral = new Dictionary<string, string>();
        private Dictionary<string, string> m_WarningsInfo = new Dictionary<string, string>();
        private Dictionary<string, string> m_WarningsWarning = new Dictionary<string, string>();
        private Dictionary<string, string> m_WarningsError = new Dictionary<string, string>();
        private Vector2 m_WarningsNeutralScrollPosition = Vector2.zero;
        private Vector2 m_WarningsInfoScrollPosition = Vector2.zero;
        private Vector2 m_WarningsWarningScrollPosition = Vector2.zero;
        private Vector2 m_WarningsErrorScrollPosition = Vector2.zero;

        private long m_NumEmbeddedWeights;
        private long m_NumConstantWeights;
        private long m_TotalWeightsSizeInBytes;

        private Vector2 m_InputsScrollPosition = Vector2.zero;
        private Vector2 m_OutputsScrollPosition = Vector2.zero;
        private Vector2 m_MemoriesScrollPosition = Vector2.zero;
        private Vector2 m_LayerScrollPosition = Vector2.zero;
        private Vector2 m_ConstantScrollPosition = Vector2.zero;
        private const float k_Space = 5f;

        // Marker that an importer warning message may start with to carry its MessageType.
        private const string k_MessageTypePrefix = "MessageType";

        private Texture2D m_IconTexture;

        /// <summary>
        /// Returns the cached importer icon, looking it up in the asset database on first use.
        /// </summary>
        /// <returns>The icon texture, or null when no matching Texture2D asset is found</returns>
        private Texture2D LoadIconTexture()
        {
            if (m_IconTexture != null)
                return m_IconTexture;

            // The search matches assets by name regardless of type, so take the first
            // candidate that actually loads as a Texture2D rather than only trying the first hit.
            foreach (string guid in AssetDatabase.FindAssets(ONNXModelImporter.iconName))
            {
                m_IconTexture = AssetDatabase.LoadAssetAtPath<Texture2D>(AssetDatabase.GUIDToAssetPath(guid));
                if (m_IconTexture != null)
                    break;
            }

            return m_IconTexture;
        }

        /// <summary>
        /// Editor static preview rendering callback
        /// </summary>
        /// <param name="assetPath">Asset path</param>
        /// <param name="subAssets">Child assets</param>
        /// <param name="width">width</param>
        /// <param name="height">height</param>
        /// <returns>A copy of the importer icon, or null when the icon cannot be loaded</returns>
        public override Texture2D RenderStaticPreview(string assetPath, UnityEngine.Object[] subAssets, int width, int height)
        {
            Texture2D icon = LoadIconTexture();
            if (icon == null)
                return null;

            Texture2D tex = new Texture2D(width, height);
            EditorUtility.CopySerialized(icon, tex);
            return tex;
        }

        /// <summary>
        /// Appends "name:value" to the builder, printing "*" for any value below 1.
        /// </summary>
        /// <param name="stringBuilder">Builder receiving the text</param>
        /// <param name="name">Dimension label</param>
        /// <param name="value">Dimension size</param>
        /// <param name="lastDim">When true no trailing ", " separator is appended</param>
        private void AddDimension(StringBuilder stringBuilder, string name, int value, bool lastDim=false)
        {
            string strValue = (value >= 1) ? value.ToString() : "*";
            stringBuilder.AppendFormat("{0}:{1}", name, strValue);
            if (!lastDim)
                stringBuilder.Append(", ");
        }

        /// <summary>
        /// Formats a shape array as "shape: (...)" for display.
        /// </summary>
        /// <param name="shape">Shape with 8 entries (s, r, n, t, d, h, w, c) or 4 entries (n, h, w, c)</param>
        /// <returns>Display string for the shape</returns>
        private string GetUIStringFromShape(int[] shape)
        {
            StringBuilder stringBuilder = new StringBuilder("shape: (", 50);
            if (shape.Length == 8)
            {
                // The s, r, t and d entries are only printed when at least one of them is above 1.
                bool is8D = (shape[0] > 1 || shape[1] > 1 || shape[3] > 1 || shape[4] > 1);
                if (is8D) AddDimension(stringBuilder, "s", shape[0]);
                if (is8D) AddDimension(stringBuilder, "r", shape[1]);
                AddDimension(stringBuilder, "n", shape[2]);
                if (is8D) AddDimension(stringBuilder, "t", shape[3]);
                if (is8D) AddDimension(stringBuilder, "d", shape[4]);
                AddDimension(stringBuilder, "h", shape[5]);
                AddDimension(stringBuilder, "w", shape[6]);
                AddDimension(stringBuilder, "c", shape[7], true);
            }
            else
            {
                UnityEngine.Debug.Assert(shape.Length == 4);
                AddDimension(stringBuilder, "n", shape[0]);
                AddDimension(stringBuilder, "h", shape[1]);
                AddDimension(stringBuilder, "w", shape[2]);
                AddDimension(stringBuilder, "c", shape[3], true);
            }
            stringBuilder.Append(")");
            return stringBuilder.ToString();
        }

        /// <summary>
        /// Splits an optional message type marker off the front of an importer warning message.
        /// The marker is the text "MessageType", one separator character and one digit
        /// holding the numeric MessageType value.
        /// </summary>
        /// <param name="description">Warning message; the marker is removed when one is recognized</param>
        /// <returns>The decoded message type, or MessageType.Warning when there is no valid marker</returns>
        private static MessageType ExtractMessageType(ref string description)
        {
            int digitIndex = k_MessageTypePrefix.Length + 1;

            // Too short to hold the marker, or not starting with it: treat as a plain warning.
            if (description == null
                || description.Length <= digitIndex
                || !description.StartsWith(k_MessageTypePrefix, StringComparison.Ordinal))
                return MessageType.Warning;

            // An out-of-range digit would otherwise map to no bucket and hide the message.
            int value = description[digitIndex] - '0';
            if (value < (int)MessageType.None || value > (int)MessageType.Error)
                return MessageType.Warning;

            description = description.Substring(digitIndex + 1);
            return (MessageType)value;
        }

        /// <summary>
        /// Rebuilds every cached display list from the inspected NNModel.
        /// Also called from OnInspectorGUI when the model changes underneath the inspector.
        /// </summary>
        void OnEnable()
        {
            var nnModel = target as NNModel;
            if (nnModel == null)
                return;
            if (nnModel.modelData == null)
                return;

            m_Model = nnModel.GetDeserializedModel();
            if (m_Model == null)
                return;

            m_Inputs = m_Model.inputs.Select(i => i.name).ToList();
            m_InputsDesc = m_Model.inputs.Select(i => GetUIStringFromShape(i.shape)).ToList();
            m_Outputs = m_Model.outputs.ToList();

            // Output shapes are only evaluated when every input shape is known well enough.
            bool allKnownInputShapes = true;
            var inputShapes = new Dictionary<string, TensorShape>();
            foreach (var i in m_Model.inputs)
            {
                allKnownInputShapes = allKnownInputShapes && ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(i);
                if (!allKnownInputShapes)
                    break;
                inputShapes.Add(i.name, new TensorShape(i.shape));
            }

            if (allKnownInputShapes)
            {
                m_OutputsDesc = m_Model.outputs.Select(i =>
                {
                    string output = "shape: (n:*, h:*, w:*, c:*)";
                    try
                    {
                        TensorShape shape;
                        if (ModelAnalyzer.TryGetOutputTensorShape(m_Model, inputShapes, i, out shape))
                            output = GetUIStringFromShape(shape.ToArray());
                    }
                    catch (Exception e)
                    {
                        Debug.LogError($"Unexpected error while evaluating model output {i}. {e}");
                    }
                    return output;
                }).ToList();
            }
            else
            {
                m_OutputsDesc = m_Model.outputs.Select(i => "shape: (n:*, h:*, w:*, c:*)").ToList();
            }

            m_Memories = m_Model.memories.Select(i => i.input).ToList();
            m_MemoriesDesc = m_Model.memories.Select(i => $"shape:{i.shape.ToString()} output:{i.output}").ToList();

            // Materialize once: both sequences are enumerated several times below.
            var layers = m_Model.layers.Where(i => i.type != Layer.Type.Load).ToList();
            var constants = m_Model.layers.Where(i => i.type == Layer.Type.Load).ToList();

            m_Layers = layers.Select(i => i.type.ToString()).ToList();
            m_LayersDesc = layers.Select(i => i.ToString()).ToList();
            m_Constants = constants.Select(i => i.type.ToString()).ToList();
            m_ConstantsDesc = constants.Select(i => i.ToString()).ToList();

            m_NumEmbeddedWeights = layers.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));
            m_NumConstantWeights = constants.Sum(l => (long)l.datasets.Sum(ds => (long)ds.length));

            // weights are not loaded for UI, recompute size
            m_TotalWeightsSizeInBytes = 0;
            foreach (var layer in m_Model.layers)
            {
                foreach (var dataset in layer.datasets)
                {
                    // Widen before multiplying so the product cannot overflow a narrower type.
                    m_TotalWeightsSizeInBytes += (long)dataset.length * dataset.itemSizeInBytes;
                }
            }

            m_Metadata = new Dictionary<string, string>(m_Model.Metadata);

            // This method runs again after a re-import, so drop warnings from the previous model.
            m_WarningsNeutral.Clear();
            m_WarningsInfo.Clear();
            m_WarningsWarning.Clear();
            m_WarningsError.Clear();

            foreach (var importerWarning in m_Model.Warnings)
            {
                string warning = importerWarning.LayerName;
                string warningDesc = importerWarning.Message;
                MessageType messageType = ExtractMessageType(ref warningDesc);

                switch (messageType)
                {
                    case MessageType.None:
                        m_WarningsNeutral[warning] = warningDesc;
                        break;
                    case MessageType.Info:
                        m_WarningsInfo[warning] = warningDesc;
                        break;
                    case MessageType.Warning:
                        m_WarningsWarning[warning] = warningDesc;
                        break;
                    case MessageType.Error:
                        m_WarningsError[warning] = warningDesc;
                        break;
                }
            }
        }

        /// <summary>
        /// Draws a button that writes the model bytes to a ".nn" file in the temporary
        /// cache folder and asks the operating system to open it.
        /// </summary>
        /// <param name="nnModel">Model whose data is written; nothing is drawn when it or its data is null</param>
        private void OpenNNModelAsTempFileButton(NNModel nnModel)
        {
            if (nnModel == null)
                return;
            if (nnModel.modelData == null)
                return;

            if (GUILayout.Button("Open imported NN model as temp file"))
            {
                string tempPath = Application.temporaryCachePath;
                string filePath = Path.Combine(tempPath, nnModel.name);
                string filePathWithExtension = Path.ChangeExtension(filePath, "nn");
                try
                {
                    File.WriteAllBytes(filePathWithExtension, nnModel.modelData.Value);
                    System.Diagnostics.Process.Start(filePathWithExtension);
                }
                catch (Exception e)
                {
                    // Writing or launching can fail; report it instead of aborting the GUI pass.
                    Debug.LogError($"Failed to open NN model as temp file {filePathWithExtension}. {e}");
                }
            }
        }

        /// <summary>
        /// Editor UI rendering callback
        /// </summary>
        public override void OnInspectorGUI()
        {
            if (m_Model == null)
                return;

            // HACK: When inspector settings are applied and the file is re-imported there doesn't seem to be a clean way to
            // get a notification from Unity, so we detect this change
            var nnModel = target as NNModel;
            if (nnModel && m_Model != nnModel.GetDeserializedModel())
            {
                OnEnable(); // Model data changed underneath while inspector was active, so reload

                // The reload can leave no model behind; there is nothing to draw then.
                if (m_Model == null)
                    return;
            }

            GUI.enabled = true;
            OpenNNModelAsTempFileButton(nnModel);
            GUILayout.Label($"Source: {m_Model.IrSource}");
            GUILayout.Label($"Version: {m_Model.IrVersion}");
            GUILayout.Label($"Producer Name: {m_Model.ProducerName}");

            if (m_Metadata.Any())
            {
                ListUIHelper($"Metadata {m_Metadata.Count}",
                    m_Metadata.Keys.ToList(), m_Metadata.Values.ToList(), ref m_MetadataScrollPosition);
            }

            if (m_WarningsError.Any())
            {
                ListUIHelper($"Errors {m_WarningsError.Count.ToString()}", m_WarningsError.Keys.ToList(), m_WarningsError.Values.ToList(), ref m_WarningsErrorScrollPosition);
                EditorGUILayout.HelpBox("Model contains errors. Behavior might be incorrect", MessageType.Error, true);
            }
            if (m_WarningsWarning.Any())
            {
                ListUIHelper($"Warnings {m_WarningsWarning.Count.ToString()}", m_WarningsWarning.Keys.ToList(), m_WarningsWarning.Values.ToList(), ref m_WarningsWarningScrollPosition);
                EditorGUILayout.HelpBox("Model contains warnings. Behavior might be incorrect", MessageType.Warning, true);
            }
            if (m_WarningsInfo.Any())
            {
                ListUIHelper("Information: ", m_WarningsInfo.Keys.ToList(), m_WarningsInfo.Values.ToList(), ref m_WarningsInfoScrollPosition);
                EditorGUILayout.HelpBox("Model contains import information.", MessageType.Info, true);
            }
            if (m_WarningsNeutral.Any())
            {
                ListUIHelper("Comments: ", m_WarningsNeutral.Keys.ToList(), m_WarningsNeutral.Values.ToList(), ref m_WarningsNeutralScrollPosition);
            }

            var constantWeightInfo = m_Constants.Count > 0 ? $" using {m_NumConstantWeights:n0} weights" : "";
            ListUIHelper($"Inputs ({m_Inputs.Count})", m_Inputs, m_InputsDesc, ref m_InputsScrollPosition);
            ListUIHelper($"Outputs ({m_Outputs.Count})", m_Outputs, m_OutputsDesc, ref m_OutputsScrollPosition);
            ListUIHelper($"Memories ({m_Memories.Count})", m_Memories, m_MemoriesDesc, ref m_MemoriesScrollPosition);
            // The layer list gets extra height when there is no constants list below it.
            ListUIHelper($"Layers ({m_Layers.Count} using {m_NumEmbeddedWeights:n0} embedded weights)", m_Layers, m_LayersDesc, ref m_LayerScrollPosition, m_Constants.Count == 0 ? 1.5f: 1f);
            ListUIHelper($"Constants ({m_Constants.Count}{constantWeightInfo})", m_Constants, m_ConstantsDesc, ref m_ConstantScrollPosition);

            GUILayout.Label($"Total weight size: {m_TotalWeightsSizeInBytes:n0} bytes");
        }

        /// <summary>
        /// Builds the clipboard text for a whole section: the title followed by one
        /// "name description" line per entry.
        /// </summary>
        private static string BuildSectionText(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions)
        {
            var fullText = new StringBuilder();
            fullText.Append(sectionTitle);
            fullText.AppendLine();
            for (int i = 0; i < names.Count; ++i)
            {
                fullText.Append($"{names[i]} {descriptions[i]}");
                fullText.AppendLine();
            }
            return fullText.ToString();
        }

        /// <summary>
        /// Draws a foldout section holding a scrollable list of "name description" rows,
        /// each with a context menu to copy the row or the whole section.
        /// </summary>
        /// <param name="sectionTitle">Foldout label; also the key under which the foldout state is stored</param>
        /// <param name="names">Row names</param>
        /// <param name="descriptions">Row descriptions, parallel to <paramref name="names"/></param>
        /// <param name="scrollPosition">Scroll position of this section, updated in place</param>
        /// <param name="maxHeightMultiplier">Scales the 150 unit cap applied to the list height</param>
        private static void ListUIHelper(string sectionTitle, IReadOnlyList<string> names, IReadOnlyList<string> descriptions, ref Vector2 scrollPosition, float maxHeightMultiplier = 1f)
        {
            int n = names.Count;
            UnityEngine.Debug.Assert(descriptions.Count == n);
            if (descriptions.Count < n)
                return;

            GUILayout.Space(k_Space);
            if (!s_UIHelperFoldouts.TryGetValue(sectionTitle, out bool foldout))
                foldout = true;

            foldout = EditorGUILayout.Foldout(foldout, sectionTitle, true, EditorStyles.foldoutHeader);
            s_UIHelperFoldouts[sectionTitle] = foldout;
            if (!foldout || n == 0)
                return;

            float height = Mathf.Min(n * 20f + 2f, 150f * maxHeightMultiplier);
            scrollPosition = GUILayout.BeginScrollView(scrollPosition, GUI.skin.box, GUILayout.MinHeight(height));

            Event e = Event.current;
            const float lineHeight = 16.0f;

            // The row background is only drawn during repaint, so only fetch the style then.
            GUIStyle evenLineStyle = null;
            if (e.type == EventType.Repaint)
                evenLineStyle = "CN EntryBackEven";

            for (int i = 0; i < n; ++i)
            {
                Rect r = EditorGUILayout.GetControlRect(false, lineHeight);

                string name = names[i];
                string description = descriptions[i];

                // Context menu, "Copy"
                if (e.type == EventType.ContextClick && r.Contains(e.mousePosition))
                {
                    e.Use();

                    // Capture the text now; the section text is only built when a menu is actually opened.
                    string lineText = $"{name} {description}";
                    string sectionText = BuildSectionText(sectionTitle, names, descriptions);

                    var menu = new GenericMenu();
                    menu.AddItem(new GUIContent("Copy current line"), false, () =>
                    {
                        EditorGUIUtility.systemCopyBuffer = lineText;
                    });
                    menu.AddItem(new GUIContent("Copy section"), false, () =>
                    {
                        EditorGUIUtility.systemCopyBuffer = sectionText;
                    });
                    menu.ShowAsContext();
                }

                // Color even line for readability
                if (evenLineStyle != null && (i & 1) == 0)
                    evenLineStyle.Draw(r, false, false, false, false);

                // name at the left edge of the row, sized to its content
                Rect nameRect = r;
                var nameContent = new GUIContent(name);
                nameRect.xMax = nameRect.xMin + EditorStyles.miniBoldLabel.CalcSize(nameContent).x;
                GUI.Label(nameRect, nameContent, EditorStyles.miniBoldLabel);

                // description fills the rest of the row after a small gap
                Rect descriptionRect = r;
                descriptionRect.xMin = nameRect.xMax + 2;
                GUI.Label(descriptionRect, new GUIContent(description), EditorStyles.miniLabel);
            }

            GUILayout.EndScrollView();
        }
    }
}