using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using UnityEngine;
using UnityEngine.Assertions;
using UnityEngine.Profiling;

[assembly: InternalsVisibleTo("Unity.Barracuda.ONNX")]
[assembly: InternalsVisibleTo("Unity.Barracuda.Editor")]

namespace Unity.Barracuda
{
    /// <summary>
    /// Static helpers that inspect a <see cref="Model"/>: default input/output names, static per-layer
    /// shape inference, storage requirements, graph connectivity and broken-link detection.
    /// </summary>
    internal class ModelAnalyzer
    {
        /// <summary>
        /// Returns the name of the model's default input.
        /// </summary>
        /// <remarks>
        /// If the model declares exactly one input, its name is returned. Otherwise the layers are scanned in
        /// order and the first layer input that is neither the name of an earlier (or the current) layer nor a
        /// memory input is returned. Returns an empty string when no such name exists.
        /// </remarks>
        public static string GetDefaultInputName(Model model)
        {
            if (model.inputs.Count == 1)
                return model.inputs[0].name;

            var memoryInputs = new HashSet<string>();
            foreach (var m in model.memories)
                memoryInputs.Add(m.input);

            // find the first unconnected input as a default model input
            var previousLayerNames = new HashSet<string>();
            foreach (var l in model.layers)
            {
                previousLayerNames.Add(l.name);

                // Load layers do not consume an input, so their input list is not inspected
                if (l.type == Layer.Type.Load)
                    continue;

                foreach (var inputName in l.inputs)
                {
                    bool inputIsUnconnected = !previousLayerNames.Contains(inputName);
                    bool inputIsNotPartOfMemory = !memoryInputs.Contains(inputName);
                    if (inputIsUnconnected && inputIsNotPartOfMemory)
                        return inputName;
                }
            }

            return "";
        }

        /// <summary>
        /// Returns the model's single declared output, otherwise the name of the last layer,
        /// otherwise an empty string when the model has no layers.
        /// </summary>
        public static string GetDefaultOutputName(Model model)
        {
            if (model.outputs.Count == 1)
                return model.outputs[0];

            if (model.layers.Count > 0)
                return model.layers[model.layers.Count - 1].name;

            return "";
        }

        /// <summary>
        /// Statically infers the output shape of every layer. See the overload with
        /// <c>shapesByName</c> for details.
        /// </summary>
        public static TensorShape?[] ListTemporaryTensorShapes(Model model, IDictionary<string, TensorShape> inputShapes)
        {
            IDictionary<string, TensorShape?> shapesByName;
            return ListTemporaryTensorShapes(model, inputShapes, out shapesByName);
        }

        /// <summary>
        /// Statically infers the output shape of every layer of <paramref name="model"/>.
        /// </summary>
        /// <param name="model">Model whose layers are analysed in declaration order.</param>
        /// <param name="inputShapes">Known shapes of the model inputs, keyed by input name.</param>
        /// <param name="shapesByName">
        /// Receives the input shapes plus one entry per layer name; a <c>null</c> value means the shape
        /// could not be determined statically.
        /// </param>
        /// <returns>One entry per layer, in <c>model.layers</c> order; <c>null</c> entries are unknown shapes.</returns>
        /// <exception cref="NotImplementedException">A layer type has no static shape rule.</exception>
        public static TensorShape?[] ListTemporaryTensorShapes(Model model, IDictionary<string, TensorShape> inputShapes,
            out IDictionary<string, TensorShape?> shapesByName)
        {
            Profiler.BeginSample("Barracuda.ListTemporaryTensorShapes");
            try
            {
                shapesByName = new Dictionary<string, TensorShape?>();
                foreach (var entry in inputShapes)
                    shapesByName.Add(entry.Key, entry.Value);

                var shapes = new List<TensorShape?>(model.layers.Count);

                // output of the previous layer; seeded with the default input's shape
                TensorShape? O;
                shapesByName.TryGetValue(GetDefaultInputName(model), out O);

                foreach (var l in model.layers)
                {
                    // the first explicit input is used when its shape is recorded,
                    // otherwise the previous layer's output is used
                    TensorShape? Xn;
                    if (l.inputs.Length > 0 && shapesByName.TryGetValue(l.inputs[0], out TensorShape? xShape))
                        Xn = xShape;
                    else
                        Xn = O;

                    if (l.type == Layer.Type.Load)
                        O = l.datasets[0].shape; // constant: shape comes from its dataset, not from an input
                    else if (Xn == null)
                        O = null;                // unknown input -> unknown output, keep analysing later layers
                    else
                        O = InferLayerOutputShape(l, Xn.Value, shapesByName);

                    shapes.Add(O);
                    shapesByName.Add(l.name, O);
                }

                return shapes.ToArray();
            }
            finally
            {
                // always close the sample, even when a layer type throws
                Profiler.EndSample();
            }
        }

        // Infers the output shape of a single layer from its primary input shape X; null means "unknown".
        private static TensorShape? InferLayerOutputShape(Layer l, TensorShape X, IDictionary<string, TensorShape?> shapesByName)
        {
            if (l.type == Layer.Type.Dense)
            {
                Assert.IsNotNull(l.datasets);
                var W = l.datasets[0].shape;
                return new TensorShape(X.flatHeight, W.flatWidth);
            }
            if (l.type == Layer.Type.Dense3)
            {
                Assert.IsNotNull(l.datasets);
                var W = l.datasets[0].shape;
                return new TensorShape(X.batch, 1, W.channels, X.channels);
            }
            if (l.type == Layer.Type.MatMul)
            {
                return InferMatMulShape(l, X, shapesByName);
            }
            if (l.type == Layer.Type.Conv2D ||
                l.type == Layer.Type.Conv3D ||
                l.type == Layer.Type.DepthwiseConv2D)
            {
                var K = l.datasets[0].shape;
                Assert.IsNotNull(l.stride);
                Assert.IsNotNull(l.pad);
                var pad = X.AdjustPadToKernel(K, l.stride, l.pad);
                return X.ApplyKernel(K, l.stride, pad);
            }
            if (l.type == Layer.Type.Conv2DTrans)
            {
                var K = l.datasets[0].shape;
                Assert.IsNotNull(l.stride);
                Assert.IsNotNull(l.pad);
                // pool size is treated as output_adjustment aka output_padding here
                var outputAdjustment = l.pool;
                var pad = X.AdjustPadToKernel(K, l.stride, l.pad);
                return X.ApplyKernelInverse(K, l.stride, pad, outputAdjustment);
            }
            if (l.type == Layer.Type.Upsample2D)
            {
                // pool size is treated as upsample coefficient here; any other layout is left unresolved
                if (l.pool == null || l.pool.Length != 2)
                    return null;
                return new TensorShape(X.batch, X.height * l.pool[1], X.width * l.pool[0], X.channels);
            }
            if (l.type == Layer.Type.Upsample3D)
            {
                // pool size is treated as upsample coefficient here (3 coefficients for 3D)
                if (l.pool == null || l.pool.Length != 3)
                    return null;
                return new TensorShape(1, 1, X.batch, 1, X.depth * l.pool[2], X.height * l.pool[1], X.width * l.pool[0], X.channels);
            }
            if (l.type == Layer.Type.Resample2D)
            {
                // pool is treated as resample size here
                var size = l.pool;
                if (size == null || size.Length != 2)
                    return null;
                return new TensorShape(X.batch, size[1], size[0], X.channels);
            }
            if (l.type == Layer.Type.DepthToSpace)
            {
                // pool size is treated as blocksize here
                Assert.IsNotNull(l.pool);
                Assert.AreEqual(l.pool.Length, 2);
                Assert.AreEqual(X.channels % (l.pool[0] * l.pool[1]), 0);
                return new TensorShape(X.batch, X.height * l.pool[1], X.width * l.pool[0], X.channels / (l.pool[0] * l.pool[1]));
            }
            if (l.type == Layer.Type.SpaceToDepth)
            {
                // pool size is treated as blocksize here
                Assert.IsNotNull(l.pool);
                Assert.AreEqual(l.pool.Length, 2);
                return new TensorShape(X.batch, X.height / l.pool[1], X.width / l.pool[0], X.channels * (l.pool[0] * l.pool[1]));
            }
            if (l.type == Layer.Type.MaxPool2D ||
                l.type == Layer.Type.AvgPool2D)
            {
                Assert.IsNotNull(l.pool);
                Assert.IsNotNull(l.stride);
                Assert.IsNotNull(l.pad);
                var pad = X.AdjustPadToPool(l.pool, l.stride, l.pad);
                return X.ApplyPool(l.pool, l.stride, pad);
            }
            if (l.type == Layer.Type.GlobalMaxPool2D ||
                l.type == Layer.Type.GlobalAvgPool2D)
            {
                return new TensorShape(X.batch, 1, 1, X.channels);
            }
            if (l.type == Layer.Type.Border3D)
            {
                Assert.IsNotNull(l.pad);
                // legacy support: a 6-element pad is widened to 8 elements by inserting 0 at indices 3 and 7
                if (l.pad.Length == 6)
                    return X.ApplyBorder(new[] { l.pad[0], l.pad[1], l.pad[2], 0, l.pad[3], l.pad[4], l.pad[5], 0 });
                return X.ApplyBorder(l.pad);
            }
            if (l.type == Layer.Type.Border2D ||
                l.type == Layer.Type.Pad2DReflect ||
                l.type == Layer.Type.Pad2DSymmetric ||
                l.type == Layer.Type.Pad2DEdge)
            {
                Assert.IsNotNull(l.pad);
                // legacy support: a 4-element pad is widened to 6 elements by inserting 0 at indices 2 and 5
                if (l.pad.Length == 4)
                    return X.ApplyBorder(new[] { l.pad[0], l.pad[1], 0, l.pad[2], l.pad[3], 0 });
                return X.ApplyBorder(l.pad);
            }
            if (l.type == Layer.Type.Conv3DTrans ||
                l.type == Layer.Type.MaxPool3D ||
                l.type == Layer.Type.AvgPool3D ||
                l.type == Layer.Type.GlobalMaxPool3D ||
                l.type == Layer.Type.GlobalAvgPool3D)
            {
                // Conv3D, Upsample3D and Border3D have rules above; these 3D types do not
                throw new NotImplementedException("3D operations are not implemented yet!");
            }
            if (l.type == Layer.Type.RandomNormal ||
                l.type == Layer.Type.RandomUniform)
            {
                Assert.IsNotNull(l.pool);
                // pool size is treated as shape constant, if not empty
                // otherwise shape of the previous tensor is used
                if (l.pool.Length > 0)
                    return new TensorShape(l.pool);
                return X;
            }
            if (l.type == Layer.Type.ConstantOfShape)
            {
                if (l.axis != 1)
                    return null;
                return X;
            }
            if (l.type == Layer.Type.Multinomial)
            {
                Assert.IsNotNull(l.pool);
                Assert.AreEqual(l.pool.Length, 1);
                return new TensorShape(X.batch, l.pool[0]);
            }
            if (l.type == Layer.Type.OneHot)
            {
                Assert.IsNotNull(l.pool);
                Assert.AreEqual(l.pool.Length, 1);
                int depth = l.pool[0];
                // axis carries the input rank here; a negative value falls back to X.dimensions
                int inputRank = l.axis < 0 ? X.dimensions : l.axis;
                if (inputRank == 1)
                    return new TensorShape(X.flatHeight, depth);
                if (inputRank == 2)
                    return new TensorShape(X.flatHeight, 1, depth, X.flatWidth);
                return new TensorShape(X.batch, X.height, depth, X.channels);
            }
            if (l.type == Layer.Type.RoiAlign)
            {
                Assert.IsNotNull(l.pool);
                Assert.AreEqual(l.pool.Length, 2);
                if (shapesByName.TryGetValue(l.inputs[1], out TensorShape? secondInputShape) && secondInputShape != null)
                {
                    int batches = secondInputShape.Value.flatHeight;
                    return new TensorShape(batches, l.pool[0], l.pool[1], X.channels);
                }
                return null;
            }
            if (l.type == Layer.Type.Add ||
                l.type == Layer.Type.Sub ||
                l.type == Layer.Type.Mul ||
                l.type == Layer.Type.Div ||
                l.type == Layer.Type.Pow ||
                l.type == Layer.Type.Min ||
                l.type == Layer.Type.Max ||
                l.type == Layer.Type.Mean ||
                l.type == Layer.Type.Greater ||
                l.type == Layer.Type.GreaterEqual ||
                l.type == Layer.Type.Less ||
                l.type == Layer.Type.LessEqual ||
                l.type == Layer.Type.Equal ||
                l.type == Layer.Type.LogicalOr ||
                l.type == Layer.Type.LogicalAnd ||
                l.type == Layer.Type.LogicalXor ||
                l.type == Layer.Type.Where)
            {
                // gather shapes by names; any unknown operand makes the result unknown
                var operandShapes = new List<TensorShape>(l.inputs.Length);
                foreach (var input in l.inputs)
                {
                    if (!shapesByName.TryGetValue(input, out TensorShape? operandShape) || operandShape == null)
                        return null;
                    operandShapes.Add(operandShape.Value);
                }
                return TensorExtensions.Max(operandShapes.ToArray());
            }
            if (l.type == Layer.Type.ReduceL1 ||
                l.type == Layer.Type.ReduceL2 ||
                l.type == Layer.Type.ReduceLogSum ||
                l.type == Layer.Type.ReduceLogSumExp ||
                l.type == Layer.Type.ReduceMax ||
                l.type == Layer.Type.ReduceMean ||
                l.type == Layer.Type.ReduceMin ||
                l.type == Layer.Type.ReduceProd ||
                l.type == Layer.Type.ReduceSum ||
                l.type == Layer.Type.ReduceSumSquare ||
                l.type == Layer.Type.ArgMax ||
                l.type == Layer.Type.ArgMin)
            {
                return X.Reduce(l.axis);
            }
            if (l.type == Layer.Type.Flatten)
            {
                return X.Flatten();
            }
            if (l.type == Layer.Type.Reshape)
            {
                // pool size is treated as the shape, if not empty
                var size = l.pool;
                Assert.IsNotNull(size);
                if (size.Length == 0 && l.inputs.Length > 1)
                {
                    switch (l.axis)
                    {
                        // Legacy - use the shape of the input tensor as the shape
                        case -1:
                            if (shapesByName.TryGetValue(l.inputs[1], out TensorShape? shapeSource))
                            {
                                if (shapeSource == null)
                                    return null;
                                size = shapeSource.Value.ToArray();
                            }
                            break;
                        // Use the tensor values as the shape; Calculated at runtime
                        case 1:
                            return null;
                    }
                }
                Assert.IsTrue((size.Length == 4) || (size.Length == 8));
                return X.Reshape(size);
            }
            if (l.type == Layer.Type.Expand)
            {
                // pool size is treated as new shape
                var newShape = l.pool;
                Assert.IsNotNull(newShape);
                Assert.IsTrue(newShape.Length == 8 || newShape.Length == 4);
                return new TensorShape(newShape);
            }
            if (l.type == Layer.Type.Transpose)
            {
                var permutations = l.pool;
                if (permutations == null)
                    return new TensorShape(X.flatWidth, X.flatHeight);
                Assert.IsTrue(permutations.Length == 8 || permutations.Length == 4);
                return X.Permute(permutations);
            }
            if (l.type == Layer.Type.Gather)
            {
                return InferGatherShape(l, shapesByName);
            }
            if (l.type == Layer.Type.ScatterND)
            {
                return X;
            }
            if (l.type == Layer.Type.Squeeze ||
                l.type == Layer.Type.Unsqueeze)
            {
                return X;
            }
            if (l.type == Layer.Type.Concat)
            {
                // gather shapes by names; any unknown input makes the result unknown
                var inputShapes = new List<TensorShape>(l.inputs.Length);
                foreach (var input in l.inputs)
                {
                    if (!shapesByName.TryGetValue(input, out TensorShape? concatInputShape) || concatInputShape == null)
                        return null;
                    inputShapes.Add(concatInputShape.Value);
                }
                return TensorExtensions.Concat(inputShapes.ToArray(), l.axis);
            }
            if (l.type == Layer.Type.StridedSlice)
            {
                Assert.IsNotNull(l.pad);
                Assert.IsNotNull(l.pool);
                Assert.IsNotNull(l.stride);
                return X.ApplyStridedSlice(l.pad, l.pool, l.stride);
            }
            if (l.type == Layer.Type.Tile)
            {
                // pool size is treated as tiling coefficient here
                Assert.IsNotNull(l.pool);
                return X.Scale(l.pool);
            }
            if (l.type == Layer.Type.Load)
            {
                return l.datasets[0].shape;
            }
            if (// elementwise operations
                l.type == Layer.Type.Nop ||
                l.type == Layer.Type.Activation ||
                l.type == Layer.Type.ScaleBias ||
                l.type == Layer.Type.Normalization ||
                l.type == Layer.Type.LRN ||
                l.type == Layer.Type.Dropout ||
                l.type == Layer.Type.LogicalNot ||
                l.type == Layer.Type.Sign)
            {
                // works in place, keeps the same shape size
                return X;
            }
            if (l.type == Layer.Type.TopKIndices ||
                l.type == Layer.Type.TopKValues ||
                l.type == Layer.Type.NonMaxSuppression ||
                l.type == Layer.Type.LSTM ||
                l.type == Layer.Type.NonZero)
            {
                // Calculated at runtime
                return null;
            }
            if (l.type == Layer.Type.Shape)
            {
                // NOTE(review): X.length is used as the element count when axis <= 0; confirm this matches
                // what the runtime Shape op produces (a per-dimension count would be expected).
                int shapeRank = l.axis > 0 ? 1 : X.length;
                return new TensorShape(shapeRank, 1, 1, 1);
            }

            throw new NotImplementedException($"Layer type {l.type} needs to be explicitly handled");
        }

        // Infers a MatMul output shape by aligning both operands in ONNX layout and taking the larger batch dims.
        private static TensorShape? InferMatMulShape(Layer l, TensorShape X, IDictionary<string, TensorShape?> shapesByName)
        {
            if (!shapesByName.TryGetValue(l.inputs[1], out TensorShape? yShape) || yShape == null)
                return null;
            var Y = yShape.Value;

            int rankX;
            int rankY;
            if (l.pool == null || l.pool.Length == 0)
            {
                LegacyGetXYRanks(X, Y, out rankX, out rankY);
            }
            else
            {
                rankX = l.pool[0];
                rankY = l.pool[1];
            }

            var onnxXshape = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaShapeToOnnxLayout(X, rankX);
            var onnxYshape = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaShapeToOnnxLayout(Y, rankY);

            int rankO = Math.Max(rankX, rankY);

            // pad 1 on front of shape to both be rankO shape
            for (int i = 0; i < (rankX - rankY); i++)
                onnxYshape.Insert(0, 1);

            for (int i = 0; i < (rankY - rankX); i++)
                onnxXshape.Insert(0, 1);

            if (rankO == 2)
                return new TensorShape(onnxXshape[0], 1, 1, onnxYshape[1]);
            if (rankO == 3)
                return new TensorShape(Math.Max(onnxXshape[0], onnxYshape[0]), 1, onnxYshape[2], onnxXshape[1]);
            return new TensorShape(Math.Max(onnxXshape[0], onnxYshape[0]), onnxXshape[2], onnxYshape[3], Math.Max(onnxXshape[1], onnxYshape[1]));
        }

        // Infers a Gather output shape; null when either the data or the indices shape is unknown.
        private static TensorShape? InferGatherShape(Layer l, IDictionary<string, TensorShape?> shapesByName)
        {
            if (!shapesByName.TryGetValue(l.inputs[0], out TensorShape? input0Shape) || input0Shape == null ||
                !shapesByName.TryGetValue(l.inputs[1], out TensorShape? input1Shape) || input1Shape == null)
                return null;

            int[] shape = input0Shape.Value.ToArray();
            shape[l.axis] = input1Shape.Value.length;
            TensorShape O = new TensorShape(shape);

            // pool = { xRank, indicesRank } when the importer provided ranks
            if (l.pool != null && l.pool.Length == 2 && l.pool[1] > 1)
            {
                int xRank = l.pool[0];
                int indicesRank = l.pool[1];

                var oShape = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaShapeToList(O, xRank);
                var indicesShape = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaShapeToList(input1Shape.Value, indicesRank);

                int axis = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaAxisToTensor(l.axis, xRank);

                // replace the gathered axis with the full indices shape
                oShape.InsertRange(axis, indicesShape);
                oShape.RemoveAt(axis + indicesShape.Count);

                O = O.Reshape(Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaLayoutToTensorShapeLayout(oShape.ToArray()));

                // rank 2 -> 3
                if (xRank == 2 && oShape.Count == 3)
                    O = O.Permute(new int[] { 0, 1, 3, 2 });
            }

            return O;
        }

        // TODO: Remove when the legacy importer / code path is no longer needed (i.e. when pool is always set)
        /// <summary>
        /// Guesses the ONNX ranks of the two MatMul operands from their Barracuda shapes.
        /// </summary>
        /// <remarks>
        /// Each shape is read as { batch, channels, height, width } (or { batch, channels, width, 1 } when
        /// height is 1); the rank is the 1-based index of the last dimension that is not 1, or 0 if all are 1.
        /// </remarks>
        public static void LegacyGetXYRanks(TensorShape X, TensorShape Y, out int rankX, out int rankY)
        {
            // ONNX rank 2 : N,C => N,1,1,C
            //      rank 3 : one must be N C W, (batches = N) => N, 1, W, C
            //      rank 4 : one must be N C H W, (batches = N * C) => N H W C
            // X and Y can be different ranks
            var onnxXshape = new List<int> { X.batch, X.channels, X.height, X.width };
            if (X.height == 1) onnxXshape = new List<int> { X.batch, X.channels, X.width, 1 };
            var onnxYshape = new List<int> { Y.batch, Y.channels, Y.height, Y.width };
            if (Y.height == 1) onnxYshape = new List<int> { Y.batch, Y.channels, Y.width, 1 };

            rankX = 0;
            for (int i = 3; i >= 0; i--)
            {
                if (onnxXshape[i] != 1)
                {
                    rankX = i + 1;
                    break;
                }
            }

            rankY = 0;
            for (int i = 3; i >= 0; i--)
            {
                if (onnxYshape[i] != 1)
                {
                    rankY = i + 1;
                    break;
                }
            }
        }

        /// <summary>
        /// Runs static shape inference and looks up the shape recorded for <paramref name="output"/>.
        /// </summary>
        /// <returns><c>true</c> when a non-null shape was inferred; otherwise <c>false</c> and a default shape.</returns>
        public static bool TryGetOutputTensorShape(Model model, IDictionary<string, TensorShape> inputShapes, string output, out TensorShape shape)
        {
            shape = new TensorShape();
            IDictionary<string, TensorShape?> shapesByName;
            ListTemporaryTensorShapes(model, inputShapes, out shapesByName);

            TensorShape? dynamicShape;
            bool found = shapesByName.TryGetValue(output, out dynamicShape) && dynamicShape != null;
            if (found)
                shape = dynamicShape.Value;
            return found;
        }

        /// <summary>
        /// Same as the overload taking input shapes, using the shapes declared on <c>model.inputs</c>.
        /// </summary>
        public static bool TryGetOutputTensorShape(Model model, string output, out TensorShape shape)
        {
            var inputShapes = new Dictionary<string, TensorShape>();
            foreach (var i in model.inputs)
                inputShapes.Add(i.name, new TensorShape(i.shape));
            return TryGetOutputTensorShape(model, inputShapes, output, out shape);
        }

        /// <summary>
        /// Finds the first layer named <paramref name="name"/>.
        /// </summary>
        /// <returns><c>true</c> if found; otherwise <c>false</c> and a placeholder Nop layer with an empty name.</returns>
        public static bool FindLayerByName(Model model, string name, out Layer layer)
        {
            layer = new Layer("", Layer.Type.Nop);
            foreach (var l in model.layers)
            {
                if (l.name == name)
                {
                    layer = l;
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Returns the layers whose output must be kept in storage: Load and Nop layers, layers consumed by
        /// a layer other than the one immediately following them, model outputs, memory outputs and the
        /// default output.
        /// </summary>
        public static HashSet<Layer> FindLayersThatRequireStorage(Model model)
        {
            // NOTE(review): inputs of the first layer are never recorded because prevLayer is still null there.
            var allInputsExceptFromPreviousLayer = new HashSet<string>();
            Layer prevLayer = null;
            foreach (var layer in model.layers)
            {
                foreach (var input in layer.inputs)
                    if (prevLayer != null && input != prevLayer.name)
                        allInputsExceptFromPreviousLayer.Add(input);
                prevLayer = layer;
            }

            var allOutputs = new HashSet<string>();
            foreach (var output in model.outputs)
                allOutputs.Add(output);
            foreach (var memory in model.memories)
                allOutputs.Add(memory.output);
            allOutputs.Add(GetDefaultOutputName(model));

            var requireStorage = new HashSet<Layer>();
            foreach (var layer in model.layers)
            {
                // loading constant tensor requires storage
                if (layer.type == Layer.Type.Load)
                    requireStorage.Add(layer);

                // @TBD: implement safety check that ensures Nop never has input
                // otherwise it has to be treated as Load operation
                if (layer.type == Layer.Type.Nop)
                    requireStorage.Add(layer);

                if (allInputsExceptFromPreviousLayer.Contains(layer.name) ||
                    allOutputs.Contains(layer.name))
                    requireStorage.Add(layer);
            }

            return requireStorage;
        }

        /// <summary>
        /// Returns every layer reachable backwards (through layer inputs) from the given output names,
        /// including the output layers themselves. Names that are not layers are ignored.
        /// </summary>
        /// <exception cref="ArgumentException">Two layers share the same name.</exception>
        public static HashSet<Layer> FindUpstreamLayers(Model model, string[] outputs)
        {
            var layersByName = model.layers.ToDictionary(i => i.name, i => i);

            var connected = new HashSet<Layer>();
            var layersToVisit = new HashSet<Layer>();
            foreach (var o in outputs)
            {
                if (layersByName.TryGetValue(o, out var outputLayer))
                {
                    layersToVisit.Add(outputLayer);
                    connected.Add(outputLayer);
                }
            }

            while (layersToVisit.Count > 0)
            {
                var visitNext = new HashSet<Layer>();
                foreach (var l in layersToVisit)
                {
                    foreach (var i in l.inputs)
                    {
                        // expand only newly discovered layers: shared ancestors are walked once
                        // and a cyclic graph cannot loop forever
                        if (layersByName.TryGetValue(i, out var inputLayer) && connected.Add(inputLayer))
                            visitNext.Add(inputLayer);
                    }
                }
                layersToVisit = visitNext;
            }

            return connected;
        }

        /// <summary>
        /// Returns the inferred layer shape with the largest element count, or 1x1x1x1 if none is larger.
        /// Unknown (null) shapes are ignored.
        /// </summary>
        public static TensorShape FindLargestNecessaryTensorShape(Model model, IDictionary<string, TensorShape> inputShapes)
        {
            Profiler.BeginSample("Barracuda.FindLargestNecessaryTensorShape");
            try
            {
                var shapes = ListTemporaryTensorShapes(model, inputShapes);

                var maxTensorShape = new TensorShape(1, 1, 1, 1);
                foreach (var X in shapes)
                    if (X?.length > maxTensorShape.length)
                        maxTensorShape = X.Value;

                return maxTensorShape;
            }
            finally
            {
                Profiler.EndSample();
            }
        }

        /// <summary>
        /// Returns the dataset shape with the largest element count across all layers, or 1x1x1x1 if none is larger.
        /// </summary>
        public static TensorShape FindLargestArgumentTensorShape(Model model)
        {
            TensorShape maxTensorShape = new TensorShape(1, 1, 1, 1);
            foreach (var layer in model.layers)
                foreach (var arg in layer.datasets)
                    if (arg.shape.length > maxTensorShape.length)
                        maxTensorShape = arg.shape;

            return maxTensorShape;
        }

        /// <summary>
        /// Returns names of layers that are not consumed by any layer, are not model or memory outputs,
        /// and do not carry the <c>Preserve</c> flag.
        /// </summary>
        public static string[] FindUnusedLayers(Model model)
        {
            var layerUsageByName = model.layers.ToDictionary(i => i.name, i => false);
            foreach (var layer in model.layers)
            {
                if (layer.flags.HasFlag(Layer.Flags.Preserve))
                    layerUsageByName[layer.name] = true;

                foreach (var i in layer.inputs)
                    layerUsageByName[i] = true;
            }

            foreach (var o in model.outputs)
                layerUsageByName[o] = true;

            foreach (var mem in model.memories)
                layerUsageByName[mem.output] = true;

            return layerUsageByName.Where(keyValue => !keyValue.Value).Select(keyValue => keyValue.Key).ToArray();
        }

        // Returns the subset of links that are neither layer names, model inputs nor memory inputs.
        private static string[] FindBrokenLinks(Model model, HashSet<string> links)
        {
            var allVariables = new HashSet<string>(model.layers.Select(i => i.name));
            var globalInputs = new HashSet<string>(model.inputs.Select(i => i.name));
            var memoryInputs = new HashSet<string>(model.memories.Select(i => i.input));
            allVariables.UnionWith(globalInputs);
            allVariables.UnionWith(memoryInputs);

            // work on a copy so the caller's set is not modified
            var brokenLinks = new HashSet<string>(links);
            brokenLinks.ExceptWith(allVariables);
            return brokenLinks.ToArray();
        }

        private static string[] FindBrokenLinks(Model model, string[] links)
        {
            return FindBrokenLinks(model, new HashSet<string>(links));
        }

        /// <summary>
        /// Returns model outputs and layer inputs that do not refer to a layer, model input or memory input.
        /// </summary>
        public static string[] FindBrokenLinks(Model model)
        {
            // check global outputs
            var linksToInspect = new HashSet<string>(model.outputs);

            // and all layers
            foreach (var layer in model.layers)
                foreach (var i in layer.inputs)
                    linksToInspect.Add(i);

            return FindBrokenLinks(model, linksToInspect);
        }

        /// <summary>
        /// Returns names of model inputs that are neither model outputs nor consumed by any layer.
        /// </summary>
        public static string[] FindUnconnectedInputs(Model model)
        {
            var unconnected = model.inputs.ToDictionary(i => i.name, i => true);

            // check global outputs
            foreach (var o in model.outputs)
                unconnected.Remove(o);

            // and all layers
            foreach (var layer in model.layers)
                foreach (var i in layer.inputs)
                    unconnected.Remove(i);

            return unconnected.Keys.ToArray();
        }

        /// <summary>
        /// Returns the names of layers that consume <paramref name="layerName"/>, plus
        /// <paramref name="layerName"/> itself when it is a model output (duplicates removed).
        /// </summary>
        public static string[] FindLayerOutputs(Model model, string layerName)
        {
            var consumers = model.layers.Where(x => x.inputs.Contains(layerName)).Select(x => x.name);
            var globalOutputs = model.outputs.Where(x => x == layerName);
            // Union returns a new sequence; its result must be used for global outputs to be reported
            return consumers.Union(globalOutputs).ToArray();
        }

        /// <summary>
        /// Returns model outputs that do not refer to a layer, model input or memory input.
        /// </summary>
        public static string[] FindUnconnectedOutputs(Model model)
        {
            return FindBrokenLinks(model, model.outputs.ToArray());
        }

        /// <summary>
        /// Returns <c>true</c> for the broadcasting element-wise/comparison/logical layer types and Concat.
        /// </summary>
        public static bool IsLayerBroacastable(Layer layer)
        {
            return layer.type == Layer.Type.Add ||
                   layer.type == Layer.Type.Sub ||
                   layer.type == Layer.Type.Mul ||
                   layer.type == Layer.Type.Div ||
                   layer.type == Layer.Type.Pow ||
                   layer.type == Layer.Type.Min ||
                   layer.type == Layer.Type.Max ||
                   layer.type == Layer.Type.Mean ||
                   layer.type == Layer.Type.Greater ||
                   layer.type == Layer.Type.GreaterEqual ||
                   layer.type == Layer.Type.Less ||
                   layer.type == Layer.Type.LessEqual ||
                   layer.type == Layer.Type.Equal ||
                   layer.type == Layer.Type.LogicalOr ||
                   layer.type == Layer.Type.LogicalAnd ||
                   layer.type == Layer.Type.LogicalXor ||
                   layer.type == Layer.Type.Where ||
                   layer.type == Layer.Type.Concat;
        }

        /// <summary>
        /// Returns <c>true</c> only for a ConstantOfShape layer whose axis is not 1 (dynamic shape support).
        /// </summary>
        public static bool IsLayerBroadcastSkippable(Layer layer)
        {
            // dynamic shape support
            return layer.type == Layer.Type.ConstantOfShape && layer.axis != 1;
        }

        // Allow some unknown input dimension for shape inference pass
        // for now batch does not yield problematic shape inference, so allow for unknown batch
        /// <summary>
        /// Returns <c>true</c> when every input dimension is positive, except the batch dimension which may be unknown (&lt;= 0).
        /// </summary>
        public static bool IsInputShapeAcceptablyKnowForShapeInference(Model.Input input) // acceptable unknown shape : N
        {
            for (int i = 0; i < input.shape.Length; i++)
            {
                var x = input.shape[i];
                if (x <= 0 && i != TensorShape.DataBatch)
                    return false;
            }
            return true;
        }

        /// <summary>
        /// Compares the order of the non-unit dimensions of <paramref name="shape"/> before and after
        /// applying <paramref name="permutations"/> (4-element permutations are first expanded to 8D).
        /// </summary>
        /// <returns>
        /// <c>true</c> when the non-unit dimensions keep their relative order.
        /// NOTE(review): despite the method name, <c>true</c> therefore means the layout is preserved — confirm against callers.
        /// </returns>
        public static bool DoesTransposeChangeTensorLayout(TensorShape shape, int[] permutations)
        {
            var activeDimLayout = new List<int>();
            for (int i = 0; i < 8; i++)
            {
                if (shape[i] != 1)
                    activeDimLayout.Add(i);
            }

            if (permutations.Length == 4)
                permutations = TensorExtensions.Get8DPermutationsForNHWCPermutationsAndShape(shape, permutations);

            // source axis index for each output axis
            var transposedLayout = TensorExtensions.Permute(new[] { 0, 1, 2, 3, 4, 5, 6, 7 }, permutations);
            var permutedShape = shape.Permute(permutations);
            var permutedActiveDimLayout = new List<int>();
            for (int i = 0; i < 8; i++)
            {
                if (permutedShape[i] != 1)
                    permutedActiveDimLayout.Add(transposedLayout[i]);
            }

            return activeDimLayout.SequenceEqual(permutedActiveDimLayout);
        }
    }
} // namespace Unity.Barracuda