Files
unity-application/Packages/com.unity.barracuda/Runtime/Core/Model.cs
2023-03-18 19:53:17 +00:00

1335 lines
30 KiB
C#

using System;
using System.Linq; // Select
using System.Collections.Generic;
using Unity.Barracuda.Compiler.Passes;
using UnityEngine.Assertions;
using UnityEditor;
namespace Unity.Barracuda {
/// <summary>
/// Barracuda Model Layer
/// </summary>
public class Layer
{
    /// <summary>
    /// Layer Type
    /// </summary>
    public enum Type
    {
        /// <summary>
        /// No operation / identity layer
        /// </summary>
        Nop = 0,
        /// <summary>
        /// Dense layer
        /// </summary>
        Dense = 1,
        /// <summary>
        /// Matrix multiplication layer
        /// </summary>
        MatMul = 2,
        /// <summary>
        /// Rank-3 Dense Layer
        /// </summary>
        Dense3 = 3,
        /// <summary>
        /// 2D Convolution layer
        /// </summary>
        Conv2D = 20,
        /// <summary>
        /// Depthwise Convolution layer
        /// </summary>
        DepthwiseConv2D = 21,
        /// <summary>
        /// Transpose 2D Convolution layer
        /// </summary>
        Conv2DTrans = 22,
        /// <summary>
        /// Upsampling layer
        /// </summary>
        Upsample2D = 23,
        /// <summary>
        /// Max Pool layer
        /// </summary>
        MaxPool2D = 25,
        /// <summary>
        /// Average Pool layer
        /// </summary>
        AvgPool2D = 26,
        /// <summary>
        /// Global Max Pool layer
        /// </summary>
        GlobalMaxPool2D = 27,
        /// <summary>
        /// Global Average Pool layer
        /// </summary>
        GlobalAvgPool2D = 28,
        /// <summary>
        /// Border / Padding layer
        /// </summary>
        Border2D = 29,
        /// <summary>
        /// 3D Convolution layer
        /// </summary>
        Conv3D = 30,
        /// <summary>
        /// Transpose 3D Convolution layer (not yet implemented)
        /// </summary>
        Conv3DTrans = 32, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// 3D Upsampling layer
        /// </summary>
        Upsample3D = 33,
        /// <summary>
        /// 3D Max Pool layer (not yet implemented)
        /// </summary>
        MaxPool3D = 35, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// 3D Average Pool layer (not yet implemented)
        /// </summary>
        AvgPool3D = 36, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// 3D Global Max Pool layer (not yet implemented)
        /// </summary>
        GlobalMaxPool3D = 37, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// 3D Global Average Pool layer (not yet implemented)
        /// </summary>
        GlobalAvgPool3D = 38, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// 3D Border / Padding layer
        /// </summary>
        Border3D = 39,
        /// <summary>
        /// Activation layer, see `Activation` enum for activation types
        /// </summary>
        Activation = 50,
        /// <summary>
        /// Scale + Bias layer
        /// </summary>
        ScaleBias = 51,
        /// <summary>
        /// Normalization layer
        /// </summary>
        Normalization = 52,
        /// <summary>
        /// LRN (Local Response Normalization) layer
        /// </summary>
        LRN = 53,
        /// <summary>
        /// Dropout layer (does nothing in inference)
        /// </summary>
        Dropout = 60,
        /// <summary>
        /// Random sampling from normal distribution layer
        /// </summary>
        RandomNormal = 64,
        /// <summary>
        /// Random sampling from uniform distribution layer
        /// </summary>
        RandomUniform = 65,
        /// <summary>
        /// Random sampling from multinomial distribution layer
        /// </summary>
        Multinomial = 66,
        /// <summary>
        /// OneHot layer
        /// </summary>
        OneHot = 67,
        /// <summary>
        /// TopK indices layer
        /// </summary>
        TopKIndices = 68,
        /// <summary>
        /// TopK values layer
        /// </summary>
        TopKValues = 69,
        /// <summary>
        /// NonZero layer
        /// </summary>
        NonZero = 70,
        /// <summary>
        /// Range layer
        /// </summary>
        Range = 71,
        /// <summary>
        /// RoiAlign layer
        /// </summary>
        RoiAlign = 72,
        /// <summary>
        /// Addition layer
        /// </summary>
        Add = 100,
        /// <summary>
        /// Subtraction layer
        /// </summary>
        Sub = 101,
        /// <summary>
        /// Multiplication layer
        /// </summary>
        Mul = 102,
        /// <summary>
        /// Division layer
        /// </summary>
        Div = 103,
        /// <summary>
        /// Power layer
        /// </summary>
        Pow = 104,
        /// <summary>
        /// Min layer
        /// </summary>
        Min = 110,
        /// <summary>
        /// Max layer
        /// </summary>
        Max = 111,
        /// <summary>
        /// Mean layer
        /// </summary>
        Mean = 112,
        /// <summary>
        /// Reduce L1 layer (not yet implemented)
        /// </summary>
        ReduceL1 = 120, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Reduce L2 layer (not yet implemented)
        /// </summary>
        ReduceL2 = 121, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Reduce LogSum layer (not yet implemented)
        /// </summary>
        ReduceLogSum = 122, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Reduce LogSumExp layer (not yet implemented)
        /// </summary>
        ReduceLogSumExp = 123, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Reduce with Max layer
        /// </summary>
        ReduceMax = 124,
        /// <summary>
        /// Reduce with Mean layer
        /// </summary>
        ReduceMean = 125,
        /// <summary>
        /// Reduce with Min layer
        /// </summary>
        ReduceMin = 126,
        /// <summary>
        /// Reduce with Prod layer
        /// </summary>
        ReduceProd = 127,
        /// <summary>
        /// Reduce with Sum layer
        /// </summary>
        ReduceSum = 128,
        /// <summary>
        /// Reduce with SumSquare layer (not yet implemented)
        /// </summary>
        ReduceSumSquare = 129, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Logic operation: Greater layer
        /// </summary>
        Greater = 140,
        /// <summary>
        /// Logic operation: GreaterEqual layer
        /// </summary>
        GreaterEqual = 141,
        /// <summary>
        /// Logic operation: Less layer
        /// </summary>
        Less = 142,
        /// <summary>
        /// Logic operation: LessEqual layer
        /// </summary>
        LessEqual = 143,
        /// <summary>
        /// Logic operation: Equal layer
        /// </summary>
        Equal = 144,
        /// <summary>
        /// Logic operation: LogicalOr layer
        /// </summary>
        LogicalOr = 145,
        /// <summary>
        /// Logic operation: LogicalAnd layer
        /// </summary>
        LogicalAnd = 146,
        /// <summary>
        /// Logic operation: LogicalNot layer
        /// </summary>
        LogicalNot = 147,
        /// <summary>
        /// Logic operation: LogicalXor layer
        /// </summary>
        LogicalXor = 148,
        /// <summary>
        /// Logic operation: Where layer
        /// </summary>
        Where = 149,
        /// <summary>
        /// Sign layer
        /// </summary>
        Sign = 150,
        /// <summary>
        /// Generic Pad layer (not fully supported)
        /// </summary>
        Pad = 159, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Reflection padding layer
        /// </summary>
        Pad2DReflect = 160,
        /// <summary>
        /// Symmetric padding layer
        /// </summary>
        Pad2DSymmetric = 161,
        /// <summary>
        /// Edge padding layer
        /// </summary>
        Pad2DEdge = 162,
        /// <summary>
        /// ArgMax layer
        /// </summary>
        ArgMax = 163,
        /// <summary>
        /// ArgMin layer
        /// </summary>
        ArgMin = 164,
        /// <summary>
        /// ConstantOfShape layer
        /// </summary>
        ConstantOfShape = 199,
        /// <summary>
        /// Flatten layer
        /// </summary>
        Flatten = 200,
        /// <summary>
        /// Reshape layer
        /// </summary>
        Reshape = 201,
        /// <summary>
        /// Transpose layer
        /// </summary>
        Transpose = 202,
        /// <summary>
        /// Squeeze layer (not fully supported)
        /// </summary>
        Squeeze = 203, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Unsqueeze layer (not fully supported)
        /// </summary>
        Unsqueeze = 204, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// Gather layer
        /// </summary>
        Gather = 205,
        /// <summary>
        /// Depth to space layer
        /// </summary>
        DepthToSpace = 206,
        /// <summary>
        /// Space to depth layer
        /// </summary>
        SpaceToDepth = 207,
        /// <summary>
        /// Expand layer
        /// </summary>
        Expand = 208,
        /// <summary>
        /// 2D Resample layer
        /// </summary>
        Resample2D = 209,
        /// <summary>
        /// Concat layer
        /// </summary>
        Concat = 210,
        /// <summary>
        /// Strided slice layer
        /// </summary>
        StridedSlice = 211,
        /// <summary>
        /// Tile layer
        /// </summary>
        Tile = 212,
        /// <summary>
        /// Shape layer
        /// </summary>
        Shape = 213,
        /// <summary>
        /// Non max suppression layer
        /// </summary>
        NonMaxSuppression = 214,
        /// <summary>
        /// LSTM
        /// </summary>
        LSTM = 215,
        /// <summary>
        /// ScatterND
        /// </summary>
        ScatterND = 216,
        /// <summary>
        /// Constant load layer (for internal use)
        /// </summary>
        Load = 255
    }

    //Keep in sync with Tensor.cginc ACTIVATION defines and IsActivationFusable() methods in ModelBuilder.cs and FuseActivationsPass.cs
    /// <summary>
    /// Fused activations enum (each member shares its numeric value with the matching `Activation` member)
    /// </summary>
    public enum FusedActivation
    {
        /// <summary>
        /// None
        /// </summary>
        None = Activation.None,
        /// <summary>
        /// Relu
        /// </summary>
        Relu = Activation.Relu,
        /// <summary>
        /// Tanh
        /// </summary>
        Tanh = Activation.Tanh,
        /// <summary>
        /// Softplus
        /// </summary>
        Softplus = Activation.Softplus,
        /// <summary>
        /// Sigmoid
        /// </summary>
        Sigmoid = Activation.Sigmoid,
        /// <summary>
        /// Relu6
        /// </summary>
        Relu6 = Activation.Relu6,
        /// <summary>
        /// Swish
        /// </summary>
        Swish = Activation.Swish,
        /// <summary>
        /// Neg
        /// </summary>
        Neg = Activation.Neg,
        /// <summary>
        /// Sqrt
        /// </summary>
        Sqrt = Activation.Sqrt,
        /// <summary>
        /// Exp
        /// </summary>
        Exp = Activation.Exp,
        /// <summary>
        /// Log
        /// </summary>
        Log = Activation.Log,
        /// <summary>
        /// Acos
        /// </summary>
        Acos = Activation.Acos,
        /// <summary>
        /// Acosh
        /// </summary>
        Acosh = Activation.Acosh,
        /// <summary>
        /// Asin
        /// </summary>
        Asin = Activation.Asin,
        /// <summary>
        /// Asinh
        /// </summary>
        Asinh = Activation.Asinh,
        /// <summary>
        /// Atan
        /// </summary>
        Atan = Activation.Atan,
        /// <summary>
        /// Atanh
        /// </summary>
        Atanh = Activation.Atanh,
        /// <summary>
        /// Cos
        /// </summary>
        Cos = Activation.Cos,
        /// <summary>
        /// Cosh
        /// </summary>
        Cosh = Activation.Cosh,
        /// <summary>
        /// Sin
        /// </summary>
        Sin = Activation.Sin,
        /// <summary>
        /// Sinh
        /// </summary>
        Sinh = Activation.Sinh,
        /// <summary>
        /// Tan
        /// </summary>
        Tan = Activation.Tan,
        /// <summary>
        /// Erf
        /// </summary>
        Erf = Activation.Erf
    }

    /// <summary>
    /// Activation enum
    /// </summary>
    public enum Activation
    {
        /// <summary>
        /// None
        /// </summary>
        None = 0,
        /// <summary>
        /// Relu
        /// </summary>
        Relu = 1,
        /// <summary>
        /// Softmax
        /// </summary>
        Softmax = 2,
        /// <summary>
        /// Tanh
        /// </summary>
        Tanh = 3,
        /// <summary>
        /// Sigmoid
        /// </summary>
        Sigmoid = 4,
        /// <summary>
        /// Elu
        /// </summary>
        Elu = 5,
        /// <summary>
        /// Relu6
        /// </summary>
        Relu6 = 6,
        /// <summary>
        /// LeakyRelu
        /// </summary>
        LeakyRelu = 7,
        /// <summary>
        /// Selu
        /// </summary>
        Selu = 8,
        /// <summary>
        /// Swish
        /// </summary>
        Swish = 9,
        /// <summary>
        /// LogSoftmax
        /// </summary>
        LogSoftmax = 10,
        /// <summary>
        /// Softplus
        /// </summary>
        Softplus = 11,
        /// <summary>
        /// Softsign (not yet implemented)
        /// </summary>
        Softsign = 12, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// PRelu
        /// </summary>
        PRelu = 13,
        /// <summary>
        /// Hardmax (not yet implemented)
        /// </summary>
        Hardmax = 20, // TODO: NOT IMPLEMENTED
        /// <summary>
        /// HardSigmoid
        /// </summary>
        HardSigmoid = 21,
        /// <summary>
        /// Abs
        /// </summary>
        Abs = 100,
        /// <summary>
        /// Neg
        /// </summary>
        Neg = 101,
        /// <summary>
        /// Ceil
        /// </summary>
        Ceil = 102,
        /// <summary>
        /// Clip
        /// </summary>
        Clip = 103,
        /// <summary>
        /// Floor
        /// </summary>
        Floor = 104,
        /// <summary>
        /// Round
        /// </summary>
        Round = 105,
        /// <summary>
        /// Reciprocal
        /// </summary>
        Reciprocal = 110,
        /// <summary>
        /// Sqrt
        /// </summary>
        Sqrt = 111,
        /// <summary>
        /// Pow
        /// </summary>
        Pow = 112,
        /// <summary>
        /// Exp
        /// </summary>
        Exp = 113,
        /// <summary>
        /// Log
        /// </summary>
        Log = 114,
        /// <summary>
        /// Acos
        /// </summary>
        Acos = 200,
        /// <summary>
        /// Acosh
        /// </summary>
        Acosh = 201,
        /// <summary>
        /// Asin
        /// </summary>
        Asin = 202,
        /// <summary>
        /// Asinh
        /// </summary>
        Asinh = 203,
        /// <summary>
        /// Atan
        /// </summary>
        Atan = 204,
        /// <summary>
        /// Atanh
        /// </summary>
        Atanh = 205,
        /// <summary>
        /// Cos
        /// </summary>
        Cos = 206,
        /// <summary>
        /// Cosh
        /// </summary>
        Cosh = 207,
        /// <summary>
        /// Sin
        /// </summary>
        Sin = 208,
        /// <summary>
        /// Sinh
        /// </summary>
        Sinh = 209,
        /// <summary>
        /// Tan
        /// </summary>
        Tan = 210,
        /// <summary>
        /// Erf
        /// </summary>
        Erf = 211
    }

    /// <summary>
    /// Auto padding enum
    /// </summary>
    public enum AutoPad
    {
        /// <summary>
        /// NotSet
        /// </summary>
        NotSet = 1,
        /// <summary>
        /// Valid
        /// </summary>
        Valid = 0,
        /// <summary>
        /// Same upper
        /// </summary>
        SameUpper = -1,
        /// <summary>
        /// Same lower
        /// </summary>
        SameLower = -2,
    }

    /// <summary>
    /// Padding mode enum
    /// </summary>
    public enum PadMode
    {
        /// <summary>
        /// Constant
        /// </summary>
        Constant = 0,
        /// <summary>
        /// Reflect
        /// </summary>
        Reflect = 1,
        /// <summary>
        /// Edge
        /// </summary>
        Edge = 2,
        /// <summary>
        /// Symmetric. The member name keeps its original spelling because renaming it would break existing callers.
        /// </summary>
        Symetric = 3,
    }

    /// <summary>
    /// Depth to space mode enum
    /// </summary>
    public enum DepthToSpaceMode
    {
        /// <summary>
        /// DCR (Depth Column Row)
        /// </summary>
        DCR,
        /// <summary>
        /// CRD (Column Row Depth)
        /// </summary>
        CRD
    }

    /// <summary>
    /// ScatterND reduction mode
    /// </summary>
    public enum ScatterNDReductionMode
    {
        /// <summary>
        /// None
        /// </summary>
        None = 0,
        /// <summary>
        /// Add
        /// </summary>
        Add = 1,
        /// <summary>
        /// Multiply
        /// </summary>
        Mul = 2,
    }

    /// <summary>
    /// Layer param data structure: describes one named tensor stored inside the layer's `weights` array
    /// </summary>
    public struct DataSet
    {
        /// <summary>
        /// Name
        /// </summary>
        public string name;
        /// <summary>
        /// Shape
        /// </summary>
        public TensorShape shape;
        /// <summary>
        /// Offset of this dataset inside the layer's `weights` array.
        /// NOTE(review): it is passed directly as an index into `weights`, so it is presumably in elements rather than bytes. Confirm before relying on this.
        /// </summary>
        public Int64 offset;
        /// <summary>
        /// Item size in bytes
        /// </summary>
        public Int32 itemSizeInBytes;
        /// <summary>
        /// Dataset length
        /// </summary>
        public Int32 length;
    }

    /// <summary>
    /// Layer preservation flags
    /// </summary>
    [Flags]
    public enum Flags
    {
        /// <summary>
        /// No flags defined
        /// </summary>
        None = 0,
        /// <summary>
        /// Preserve the layer (e.g. don't remove it in a model pass)
        /// </summary>
        Preserve = 1 << 1,
    }

    /// <summary>
    /// Layer name
    /// </summary>
    public string name;
    /// <summary>
    /// Layer type
    /// </summary>
    public Type type;
    /// <summary>
    /// Layer flags (not serialized) - used for conversion
    /// </summary>
    [NonSerialized]
    public Flags flags;
    /// <summary>
    /// Layer activation type
    /// </summary>
    public Activation activation;
    /// <summary>
    /// Padding shape
    /// </summary>
    public Int32[] pad;
    /// <summary>
    /// Stride
    /// </summary>
    public Int32[] stride;
    /// <summary>
    /// Pooling
    /// </summary>
    public Int32[] pool;
    /// <summary>
    /// Axis
    /// </summary>
    public Int32 axis;
    /// <summary>
    /// Alpha
    /// </summary>
    public float alpha;
    /// <summary>
    /// Beta
    /// </summary>
    public float beta;
    /// <summary>
    /// Input (layer) names
    /// </summary>
    public string[] inputs;
    /// <summary>
    /// Output (layer) names (not serialized) - used for conversion
    /// </summary>
    [NonSerialized]
    public string[] outputs;
    /// <summary>
    /// Axes (not serialized) - used for conversion
    /// </summary>
    [NonSerialized]
    public Int32[] axes;
    /// <summary>
    /// Datasets bound to layer
    /// </summary>
    public DataSet[] datasets;
    /// <summary>
    /// Flat weights array (for actual shape see `datasets`)
    /// </summary>
    public BarracudaArray weights;

    // Shared initialisation: a Nop layer with empty arrays, axis -1, alpha 1 and beta 0.
    private Layer(string layerName)
    {
        name = layerName;
        type = Type.Nop;
        activation = Activation.None;
        pad = new int[0];
        stride = new int[0];
        pool = new int[0];
        axis = -1;
        alpha = 1.0f;
        beta = 0.0f;
        inputs = new string[0];
        datasets = new DataSet[0];
        weights = new BarracudaArray(0);//TODO fp16?
    }

    /// <summary>
    /// Constructs Layer
    /// </summary>
    /// <param name="layerName">layer name</param>
    /// <param name="layerType">layer type</param>
    /// <param name="activationType">layer activation type</param>
    public Layer(string layerName, Type layerType, Activation activationType = Activation.None) : this(layerName)
    {
        type = layerType;
        activation = activationType;
    }

    /// <summary>
    /// Constructs Activation Layer
    /// </summary>
    /// <param name="layerName">layer name</param>
    /// <param name="activationType">layer activation type</param>
    public Layer(string layerName, Activation activationType) : this(layerName)
    {
        type = Type.Activation;
        activation = activationType;
    }

    /// <summary>
    /// Layer summary string. Fields that still hold their default value are omitted, and
    /// the layer-name prefix is stripped from dataset names.
    /// </summary>
    /// <returns>layer summary string</returns>
    public override string ToString()
    {
        var weightsSummary = $"weights:[{string.Join(", ", datasets.Select(x => $"{x.name} {x.shape}"))}]";

        // Strip the "<layer name>/" and "<layer name> " prefixes from dataset names. Skip this when the
        // name is null or empty: the patterns would collapse to "/" and " " and remove every slash and
        // space from the weights summary, which garbles the printed shapes.
        if (!string.IsNullOrEmpty(name))
            weightsSummary = weightsSummary.Replace(name + "/", "").Replace(name + " ", "");

        var summary = $"name:{name}, activation:{activation}, inputs:[{string.Join(",", inputs)}], " +
            $"pad:[{string.Join(",", pad)}], stride:[{string.Join(",", stride)}], pool:[{string.Join(",", pool)}], " +
            $"alpha:{alpha}, beta:{beta}, axis:{axis}, " +
            weightsSummary;

        // Drop entries that still hold their default / empty values to keep the summary compact.
        return summary.
            Replace("activation:None, ", "").Replace("inputs:[], ", "").Replace("pad:[], ", "").
            Replace("stride:[], ", "").Replace("stride:[1,1], ", "").Replace("pool:[], ", "").
            Replace("alpha:1, ", "").Replace("beta:0, ", "").Replace("axis:-1, ", "").
            Replace("weights:[]", "");
    }

    /// <summary>
    /// Converts DataSet to Tensor. The tensor is backed by a `SharedArrayTensorData` over this
    /// layer's `weights`, starting at the dataset's offset.
    /// </summary>
    /// <param name="index">dataset index</param>
    /// <returns>Tensor</returns>
    public Tensor DataSetToTensor(int index)
    {
        Assert.IsTrue(index >= 0 && index < datasets.Length);
        var ds = datasets[index];
        return new Tensor(ds.shape, new SharedArrayTensorData(weights, ds.shape, (int)ds.offset), ds.name);
    }

    /// <summary>
    /// Converts Tensor to DataSet: replaces the dataset's shape with `X.shape` and copies X's values
    /// into `weights` starting at the dataset's offset.
    /// </summary>
    /// <param name="X">input `Tensor`</param>
    /// <param name="index">dataset index</param>
    public void ApplyTensorToDataSet(Tensor X, int index)
    {
        Assert.IsTrue(index >= 0 && index < datasets.Length);
        var ds = datasets[index];
        ds.shape = X.shape;
        // NOTE(review): the copy length comes from X's shape and is not checked against ds.length, and
        // ds.length is not updated. A tensor larger than the original dataset would overwrite the data
        // that follows it in `weights`. Confirm that callers only pass tensors of the original size.
        BarracudaArray.Copy(X.ToReadOnlyArray(), 0, weights, ds.offset, ds.shape.length);
        datasets[index] = ds;
    }
}
/// <summary>
/// Neural Net Model data structure
/// </summary>
public class Model
{
    /// <summary>
    /// Model version, incremented with each data structure change
    /// </summary>
    public const int Version = 20;
    // Last model version that predates 8D tensor support.
    internal const int LastVersionWithout8DSupport = 16;
    /// <summary>
    /// Last model version that predates weights alignment support
    /// </summary>
    public const int LastVersionWithoutWeightsAlignmentSupport = 18;
    // Weights alignment value. NOTE(review): the unit (elements vs. bytes) is not visible here; confirm against the loader/writer.
    internal const int WeightsAlignment = 16;

    /// <summary>
    /// Input data structure
    /// </summary>
    public struct Input
    {
        /// <summary>
        /// Name
        /// </summary>
        public string name;
        /// <summary>
        /// Shape as `int` array
        /// </summary>
        public Int32[] shape; // input shape can contain -1 for unspecified dimensions
        /// <summary>
        /// Input rank
        /// </summary>
        public int rank;

        /// <summary>
        /// Creates a copy of this input with the specified name. The shape and rank are preserved.
        /// </summary>
        /// <param name="name">name</param>
        /// <returns>Input structure</returns>
        public Input WithName(string name)
        {
            // The previous implementation dropped `rank`, so renamed inputs came back with rank 0.
            return new Input { name = name, shape = shape, rank = rank };
        }
    }

    /// <summary>
    /// Memory data structure. Used by recurrent models to store information about recurrent inputs/outputs
    /// </summary>
    public struct Memory
    {
        /// <summary>
        /// Shape
        /// </summary>
        public TensorShape shape;
        /// <summary>
        /// Input name
        /// </summary>
        public string input;
        /// <summary>
        /// Output name
        /// </summary>
        public string output;
    }

    /// <summary>
    /// Model layout
    /// </summary>
    public string layout = String.Empty;
    /// <summary>
    /// All model inputs
    /// </summary>
    public List<Input> inputs = new List<Input>();
    /// <summary>
    /// All model outputs
    /// </summary>
    public List<string> outputs = new List<string>();
    /// <summary>
    /// All model memories
    /// </summary>
    public List<Memory> memories = new List<Memory>();
    /// <summary>
    /// All model layers
    /// </summary>
    public List<Layer> layers = new List<Layer>();

    #region Importer info
    /// <summary>
    /// Model source metadata string
    /// </summary>
    public string IrSource = "Script";
    /// <summary>
    /// Model ONNX version metadata string
    /// </summary>
    public string IrVersion = "NA";
    /// <summary>
    /// Model producer metadata string
    /// </summary>
    public string ProducerName = "Script";
    /// <summary>
    /// Model import warnings
    /// </summary>
    public List<ImporterWarning> Warnings { get; } = new List<ImporterWarning>();

    /// <summary>
    /// Importer warning data structure
    /// </summary>
    public class ImporterWarning
    {
        /// <summary>
        /// Message
        /// </summary>
        public string Message { get; }
        /// <summary>
        /// Layer name
        /// </summary>
        public string LayerName { get; }

        /// <summary>
        /// Constructs ImporterWarning
        /// </summary>
        /// <param name="layer">layer name</param>
        /// <param name="msg">message</param>
        public ImporterWarning(string layer, string msg)
        {
            Message = msg;
            LayerName = layer;
        }
    }

    /// <summary>
    /// Metadata properties associated with the model
    /// </summary>
    public Dictionary<string, string> Metadata { get; private set; } = new Dictionary<string, string>();
    #endregion

    /// <summary>
    /// Build shallow copy of the model. The top-level lists and the metadata dictionary are new
    /// containers, but the `Layer` objects (and the warnings) are shared with this model.
    /// </summary>
    /// <returns>shallow copy of the model</returns>
    public Model ShallowCopy()
    {
        var model = new Model();
        // `layout` was previously not copied, so every copy silently reverted to an empty layout.
        model.layout = layout;
        model.inputs.AddRange(inputs);
        model.outputs.AddRange(outputs);
        model.memories.AddRange(memories);
        model.layers.AddRange(layers);
        model.IrSource = IrSource;
        model.IrVersion = IrVersion;
        model.ProducerName = ProducerName;
        model.Warnings.AddRange(Warnings);
        model.Metadata = new Dictionary<string, string>(Metadata);
        return model;
    }

    /// <summary>
    /// Model summary string
    /// </summary>
    /// <returns>Model summary string</returns>
    public override string ToString()
    {
        // weights are not loaded for UI, recompute size
        // Accumulate in a long: each dataset length is Int32, and the total across all layers can exceed Int32.MaxValue.
        long totalUniqueWeights = 0;
        foreach (var layer in layers)
            foreach (var dataSet in layer.datasets)
                totalUniqueWeights += dataSet.length;

        return $"inputs: [{string.Join(", ", inputs.Select(i => $"{i.name} ({string.Join(",", i.shape)})"))}], " +
            $"memories: [{string.Join(", ", memories.Select(m => $"{m.input} {m.shape} {m.output}"))}], " +
            $"outputs: [{string.Join(", ", outputs)}] " +
            $"\n{layers.Count} layers, {totalUniqueWeights:n0} weights: \n{string.Join("\n", layers.Select(i => $"{i.type} ({i})"))}";
    }

    /// <summary>
    /// Convert in place all model weights to given data type
    /// </summary>
    /// <param name="type">target type for model weights</param>
    internal void ConvertWeights(DataType type)
    {
        foreach (var layer in layers)
        {
            // Only reallocate layers whose weights are present and not already in the target type.
            if (layer.weights != null && layer.weights.Type != type)
            {
                var sourceWeights = layer.weights;
                var targetWeights = new BarracudaArray(layer.weights.Length, type);
                BarracudaArray.Copy(sourceWeights, targetWeights);
                layer.weights = targetWeights;
            }
        }
    }
}
/// <summary>
/// Model metadata extensions: name-based lookups over a model's inputs, layers, datasets and memories
/// </summary>
public static class ModelMetadataExtensions
{
    /// <summary>
    /// Get model tensor by name. The first dataset (across all layers, in layer order) whose name
    /// matches is wrapped in a Tensor backed by its layer's `weights`.
    /// </summary>
    /// <param name="model">Model</param>
    /// <param name="name">Tensor name</param>
    /// <returns>Tensor, or null when no dataset has that name</returns>
    public static Tensor GetTensorByName(this Model model, string name)
    {
        foreach (var layer in model.layers)
        {
            foreach (var dataSet in layer.datasets)
            {
                if (dataSet.name != name)
                {
                    continue;
                }

                var backing = new SharedArrayTensorData(layer.weights, dataSet.shape, (int)dataSet.offset);
                return new Tensor(dataSet.shape, backing, dataSet.name);
            }
        }

        return null;
    }

    /// <summary>
    /// Get model tensor shape by name. The lookup order is: model inputs, then the output shape
    /// reported by ModelAnalyzer, then layer datasets, then memory inputs/outputs.
    /// </summary>
    /// <param name="model">Model</param>
    /// <param name="name">Tensor name</param>
    /// <returns>Tensor shape (every successful path returns a value; a missing name throws instead of returning null)</returns>
    /// <exception cref="KeyNotFoundException">Thrown when no input, layer output, dataset or memory matches <paramref name="name"/>.</exception>
    public static TensorShape? GetShapeByName(this Model model, string name)
    {
        foreach (var input in model.inputs)
        {
            if (input.name == name)
            {
                return new TensorShape(input.shape);
            }
        }

        TensorShape outputShape;
        if (ModelAnalyzer.TryGetOutputTensorShape(model, name, out outputShape))
        {
            return outputShape;
        }

        foreach (var layer in model.layers)
        {
            foreach (var dataSet in layer.datasets)
            {
                if (dataSet.name == name)
                {
                    return dataSet.shape;
                }
            }
        }

        foreach (var memory in model.memories)
        {
            var matchesMemory = memory.input == name || memory.output == name;
            if (matchesMemory)
            {
                return memory.shape;
            }
        }

        throw new KeyNotFoundException($"Shape {name} not found!");
    }

    /// <summary>
    /// Get count of layers that directly depend on specified input, i.e. layers whose
    /// `inputs` array contains <paramref name="name"/>.
    /// </summary>
    /// <param name="model">Model</param>
    /// <param name="name">input name</param>
    /// <returns>count of layers that directly depend on specified input</returns>
    public static int GetDownStreamLayersCount(this Model model, string name)
    {
        var consumers = 0;
        foreach (var layer in model.layers)
        {
            if (Array.IndexOf(layer.inputs, name) >= 0)
            {
                consumers++;
            }
        }

        return consumers;
    }
}
} // namespace Unity.Barracuda