using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using UnityEngine;
using UnityEngine.Assertions;

namespace Unity.Barracuda
{
    /// <summary>
    /// Test set loading utility. Gives uniform read access to a test set backed either by a
    /// <see cref="RawTestSet"/> (a single unnamed input and a single unnamed output) or by a
    /// <see cref="JSONTestSet"/> (any number of named, shaped input and output tensors).
    /// </summary>
    public class TestSet
    {
        // Each constructor sets at most one of these; whichever is non-null selects the code path
        // taken by every accessor below (raw is checked first).
        private readonly RawTestSet rawTestSet;
        private readonly JSONTestSet jsonTestSet;

        /// <summary>
        /// Create with raw test set
        /// </summary>
        /// <param name="rawTestSet">raw test set</param>
        public TestSet(RawTestSet rawTestSet)
        {
            this.rawTestSet = rawTestSet;
        }

        /// <summary>
        /// Create with JSON test set
        /// </summary>
        /// <param name="jsonTestSet">JSON test set</param>
        public TestSet(JSONTestSet jsonTestSet)
        {
            this.jsonTestSet = jsonTestSet;
        }

        /// <summary>
        /// Create a `TestSet` with no backing data.
        /// </summary>
        /// <remarks>
        /// Neither a raw nor a JSON test set is attached, so every accessor except
        /// <see cref="SupportsNames"/> throws a <see cref="NullReferenceException"/> on such an instance.
        /// </remarks>
        public TestSet()
        {
        }

        /// <summary>
        /// Check if test set supports named tensors
        /// </summary>
        /// <returns>`false` for a raw test set (its tensors are unnamed), `true` otherwise</returns>
        public bool SupportsNames()
        {
            return rawTestSet == null;
        }

        /// <summary>
        /// Get output tensor count
        /// </summary>
        /// <returns>1 for a raw test set, otherwise the number of JSON outputs</returns>
        public int GetOutputCount()
        {
            if (rawTestSet != null)
                return 1;

            return jsonTestSet.outputs.Length;
        }

        /// <summary>
        /// Get output tensor data
        /// </summary>
        /// <param name="idx">tensor index (ignored for raw test sets)</param>
        /// <returns>tensor data; for a raw test set this is <see cref="RawTestSet.labels"/></returns>
        public float[] GetOutputData(int idx = 0)
        {
            if (rawTestSet != null)
                return rawTestSet.labels;

            return jsonTestSet.outputs[idx].data;
        }

        /// <summary>
        /// Get output tensor name
        /// </summary>
        /// <param name="idx">tensor index</param>
        /// <returns>tensor name with a trailing ":0" removed, or `null` for a raw test set</returns>
        public string GetOutputName(int idx = 0)
        {
            if (rawTestSet != null)
                return null;

            return StripOutputIndexSuffix(jsonTestSet.outputs[idx].name);
        }

        /// <summary>
        /// Get input tensor count
        /// </summary>
        /// <returns>1 for a raw test set, otherwise the number of JSON inputs</returns>
        public int GetInputCount()
        {
            if (rawTestSet != null)
                return 1;

            return jsonTestSet.inputs.Length;
        }

        /// <summary>
        /// Get input tensor name
        /// </summary>
        /// <param name="idx">input tensor index</param>
        /// <returns>
        /// tensor name with a trailing ":0" removed, or an empty string for a raw test set
        /// (note: <see cref="GetOutputName"/> returns `null` in that case)
        /// </returns>
        public string GetInputName(int idx = 0)
        {
            if (rawTestSet != null)
                return "";

            return StripOutputIndexSuffix(jsonTestSet.inputs[idx].name);
        }

        /// <summary>
        /// Get input tensor data
        /// </summary>
        /// <param name="idx">input tensor index (ignored for raw test sets)</param>
        /// <returns>tensor data; for a raw test set this is <see cref="RawTestSet.input"/></returns>
        public float[] GetInputData(int idx = 0)
        {
            if (rawTestSet != null)
                return rawTestSet.input;

            return jsonTestSet.inputs[idx].data;
        }

        /// <summary>
        /// Get input shape
        /// </summary>
        /// <param name="idx">input tensor index</param>
        /// <returns>
        /// input shape; for a raw test set, a shape built from (1, input data length)
        /// </returns>
        public TensorShape GetInputShape(int idx = 0)
        {
            if (rawTestSet != null)
                return new TensorShape(1, rawTestSet.input.Length);

            return ToTensorShape(jsonTestSet.inputs[idx].shape);
        }

        /// <summary>
        /// Get output tensor shape
        /// </summary>
        /// <param name="idx">output tensor index</param>
        /// <returns>
        /// tensor shape; for a raw test set, a shape built from (1, label data length)
        /// </returns>
        public TensorShape GetOutputShape(int idx = 0)
        {
            if (rawTestSet != null)
                return new TensorShape(1, rawTestSet.labels.Length);

            return ToTensorShape(jsonTestSet.outputs[idx].shape);
        }

        /// <summary>
        /// Get inputs as `Tensor` dictionary, keyed by <see cref="GetInputName"/>
        /// </summary>
        /// <param name="inputs">dictionary to store results; a new one is created when `null`. Entries with matching names are overwritten.</param>
        /// <param name="batchCount">max batch count (negative = all batches from <paramref name="fromBatch"/>)</param>
        /// <param name="fromBatch">start from batch</param>
        /// <returns>dictionary with input tensors</returns>
        /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
        public Dictionary<string, Tensor> GetInputsAsTensorDictionary(Dictionary<string, Tensor> inputs = null, int batchCount = -1, int fromBatch = 0)
        {
            if (rawTestSet != null)
                throw new NotSupportedException("GetInputsAsTensorDictionary is not supported for RAW test suites");

            if (inputs == null)
                inputs = new Dictionary<string, Tensor>();

            for (var i = 0; i < GetInputCount(); i++)
                inputs[GetInputName(i)] = GetInputAsTensor(i, batchCount, fromBatch);

            return inputs;
        }

        /// <summary>
        /// Get outputs as `Tensor` dictionary, keyed by <see cref="GetOutputName"/>
        /// </summary>
        /// <param name="outputs">dictionary to store results; a new one is created when `null`. Entries with matching names are overwritten.</param>
        /// <param name="batchCount">max batch count (negative = all batches from <paramref name="fromBatch"/>)</param>
        /// <param name="fromBatch">start from batch</param>
        /// <returns>dictionary with output tensors</returns>
        /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
        public Dictionary<string, Tensor> GetOutputsAsTensorDictionary(Dictionary<string, Tensor> outputs = null, int batchCount = -1, int fromBatch = 0)
        {
            if (rawTestSet != null)
                throw new NotSupportedException("GetOutputsAsTensorDictionary is not supported for RAW test suites");

            if (outputs == null)
                outputs = new Dictionary<string, Tensor>();

            for (var i = 0; i < GetOutputCount(); i++)
                outputs[GetOutputName(i)] = GetOutputAsTensor(i, batchCount, fromBatch);

            return outputs;
        }

        /// <summary>
        /// Get input as `Tensor`. The data is copied into a new buffer; when more batches are
        /// requested than the test set holds, the missing batches are filled with zeros.
        /// </summary>
        /// <param name="idx">input index</param>
        /// <param name="batchCount">max batch count (negative = all batches from <paramref name="fromBatch"/>)</param>
        /// <param name="fromBatch">start from batch (clamped to the last available batch)</param>
        /// <returns>`Tensor` named after <see cref="GetInputName"/></returns>
        /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
        public Tensor GetInputAsTensor(int idx = 0, int batchCount = -1, int fromBatch = 0)
        {
            if (rawTestSet != null)
                throw new NotSupportedException("GetInputAsTensor is not supported for RAW test suites");

            TensorShape shape = GetInputShape(idx);
            // Only shapes with sequenceLength == 1 and numberOfDirections == 1 are handled here.
            Assert.IsTrue(shape.sequenceLength == 1 && shape.numberOfDirections == 1);

            var array = GetInputData(idx);
            var maxBatchCount = array.Length / shape.flatWidth;
            fromBatch = Math.Min(fromBatch, maxBatchCount - 1);
            if (batchCount < 0)
                batchCount = maxBatchCount - fromBatch;

            var shapeArray = shape.ToArray();
            shapeArray[TensorShape.DataBatch] = batchCount;
            var tensorShape = new TensorShape(shapeArray);

            // Pad data with 0s if the test set doesn't have enough batches: the upload buffer is sized
            // for the requested batch count, but only the available batches are copied into it.
            var sourceStart = fromBatch * tensorShape.flatWidth;
            var copyCount = Math.Min(batchCount, maxBatchCount - fromBatch) * tensorShape.flatWidth;
            var dataToUpload = new float[tensorShape.length];
            Array.Copy(array, sourceStart, dataToUpload, 0, copyCount);

            var data = new ArrayTensorData(tensorShape.length);
            data.Upload(dataToUpload, tensorShape, 0);

            var res = new Tensor(tensorShape, data);
            // Same name as the dictionary key used by GetInputsAsTensorDictionary.
            res.name = GetInputName(idx);
            return res;
        }

        /// <summary>
        /// Get output as `Tensor`. Unlike <see cref="GetInputAsTensor"/>, outputs are never padded:
        /// the batch count is clamped to the batches available from <paramref name="fromBatch"/>.
        /// </summary>
        /// <param name="idx">output index</param>
        /// <param name="batchCount">max batch count (negative = all batches from <paramref name="fromBatch"/>)</param>
        /// <param name="fromBatch">start from batch (clamped to the last available batch)</param>
        /// <returns>`Tensor` named after <see cref="GetOutputName"/></returns>
        /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
        public Tensor GetOutputAsTensor(int idx = 0, int batchCount = -1, int fromBatch = 0)
        {
            if (rawTestSet != null)
                throw new NotSupportedException("GetOutputAsTensor is not supported for RAW test suites");

            TensorShape shape = GetOutputShape(idx);
            // Only shapes with sequenceLength == 1 and numberOfDirections == 1 are handled here.
            Assert.IsTrue(shape.sequenceLength == 1 && shape.numberOfDirections == 1);

            var barracudaArray = new BarracudaArrayFromManagedArray(GetOutputData(idx));
            var maxBatchCount = barracudaArray.Length / shape.flatWidth;
            fromBatch = Math.Min(fromBatch, maxBatchCount - 1);
            if (batchCount < 0)
                batchCount = maxBatchCount - fromBatch;
            batchCount = Math.Min(batchCount, maxBatchCount - fromBatch);

            var shapeArray = shape.ToArray();
            shapeArray[TensorShape.DataBatch] = batchCount;
            var tensorShape = new TensorShape(shapeArray);

            // The tensor is backed by SharedArrayTensorData at a batch offset into the wrapped output
            // array (presumably a view rather than a copy — confirm against SharedArrayTensorData).
            var offset = fromBatch * tensorShape.flatWidth;
            var res = new Tensor(tensorShape, new SharedArrayTensorData(barracudaArray, tensorShape, offset));
            // Same name as the dictionary key used by GetOutputsAsTensorDictionary.
            res.name = GetOutputName(idx);
            return res;
        }

        // Removes one trailing ":0" from a tensor name (presumably a TensorFlow-style output-index
        // suffix, stripped so names line up with the model's tensor names — TODO confirm).
        // Null names are passed through unchanged.
        private static string StripOutputIndexSuffix(string name)
        {
            return name != null && name.EndsWith(":0", StringComparison.Ordinal)
                ? name.Remove(name.Length - 2)
                : name;
        }

        // Converts a serialized shape into a TensorShape, passing all eight components to the
        // TensorShape constructor in the same order as the fields of JSONTensorShape.
        private static TensorShape ToTensorShape(JSONTensorShape s)
        {
            return new TensorShape(s.sequenceLength, s.numberOfDirections, s.batch, s.extraDimension,
                s.depth, s.height, s.width, s.channels);
        }
    }

    /// <summary>
    /// Raw test structure: one unnamed input array and one unnamed output array
    /// </summary>
    public class RawTestSet
    {
        /// <summary>
        /// Input data
        /// </summary>
        public float[] input;

        /// <summary>
        /// Output data
        /// </summary>
        public float[] labels;
    }

    /// <summary>
    /// JSON test structure (deserialized with <c>JsonUtility</c>)
    /// </summary>
    [Serializable]
    public class JSONTestSet
    {
        /// <summary>
        /// Inputs
        /// </summary>
        public JSONTensor[] inputs;

        /// <summary>
        /// Outputs
        /// </summary>
        public JSONTensor[] outputs;
    }

    /// <summary>
    /// JSON tensor shape: the eight dimensions passed to the <c>TensorShape</c> constructor
    /// </summary>
    [Serializable]
    public class JSONTensorShape
    {
        /// <summary>
        /// Sequence length
        /// </summary>
        public int sequenceLength;

        /// <summary>
        /// Number of directions
        /// </summary>
        public int numberOfDirections;

        /// <summary>
        /// Batch
        /// </summary>
        public int batch;

        /// <summary>
        /// Extra dimension
        /// </summary>
        public int extraDimension;

        /// <summary>
        /// Depth
        /// </summary>
        public int depth;

        /// <summary>
        /// Height
        /// </summary>
        public int height;

        /// <summary>
        /// Width
        /// </summary>
        public int width;

        /// <summary>
        /// Channels
        /// </summary>
        public int channels;
    }

    /// <summary>
    /// JSON tensor
    /// </summary>
    [Serializable]
    public class JSONTensor
    {
        /// <summary>
        /// Name (a trailing ":0" is stripped by <see cref="TestSet"/> accessors)
        /// </summary>
        public string name;

        /// <summary>
        /// Shape
        /// </summary>
        public JSONTensorShape shape;

        /// <summary>
        /// Tensor type string as stored in the JSON (not read by the loaders in this file)
        /// </summary>
        public string type;

        /// <summary>
        /// Tensor data
        /// </summary>
        public float[] data;
    }

    /// <summary>
    /// Test set loader. File-based loaders read from <c>Application.streamingAssetsPath/TestSet</c>.
    /// </summary>
    public class TestSetLoader
    {
        /// <summary>
        /// Load test set from file, choosing the format by extension (case-insensitive):
        /// ".raw" → <see cref="LoadRaw"/>, ".gz" → <see cref="LoadGZ"/>, anything else → <see cref="LoadJSON"/>.
        /// </summary>
        /// <param name="filename">file name</param>
        /// <returns>`TestSet`</returns>
        public static TestSet Load(string filename)
        {
            if (filename.EndsWith(".raw", StringComparison.OrdinalIgnoreCase))
                return LoadRaw(filename);
            if (filename.EndsWith(".gz", StringComparison.OrdinalIgnoreCase))
                return LoadGZ(filename);
            return LoadJSON(filename);
        }

        /// <summary>
        /// Load GZ: decompresses <c>TestSet/filename</c> next to the archive (same path without the
        /// ".gz" extension, overwriting any existing file) and loads the result via <see cref="LoadJSON"/>.
        /// </summary>
        /// <param name="filename">file name, expected to end with ".gz"</param>
        /// <returns>`TestSet`</returns>
        public static TestSet LoadGZ(string filename)
        {
            // Drop the 3-character ".gz" extension.
            // NOTE(review): LoadJSON only reads from StreamingAssets when this name ends in ".json";
            // otherwise it looks in Resources, where the decompressed file is not written.
            // This presumably assumes archives are named "*.json.gz" — confirm.
            var jsonFileName = filename.Substring(0, filename.Length - 3);
            var archivePath = Path.Combine(Application.streamingAssetsPath, "TestSet", filename);
            var decompressedFilePath = archivePath.Substring(0, archivePath.Length - 3);

            using (var archiveStream = File.OpenRead(archivePath))
            using (var decompressionStream = new GZipStream(archiveStream, CompressionMode.Decompress))
            using (var decompressedStream = File.Create(decompressedFilePath))
            {
                decompressionStream.CopyTo(decompressedStream);
            }

            return LoadJSON(jsonFileName);
        }

        /// <summary>
        /// Load JSON. Names ending in ".json" (case-insensitive) are read from
        /// <c>StreamingAssets/TestSet</c>; any other name is loaded as a <c>TextAsset</c> from
        /// <c>Resources/TestSet</c>.
        /// </summary>
        /// <param name="filename">file name</param>
        /// <returns>`TestSet`</returns>
        /// <exception cref="FileNotFoundException">the Resources text asset does not exist</exception>
        public static TestSet LoadJSON(string filename)
        {
            string json;
            if (filename.EndsWith(".json", StringComparison.OrdinalIgnoreCase))
            {
                json = File.ReadAllText(Path.Combine(Application.streamingAssetsPath, "TestSet", filename));
            }
            else
            {
                var asset = Resources.Load<TextAsset>($"TestSet/{filename}");
                // Use == rather than `is null` so Unity's overloaded null check on UnityEngine.Object applies.
                if (asset == null)
                    throw new FileNotFoundException($"Test set resource 'TestSet/{filename}' was not found", filename);
                json = asset.text;
            }

            return new TestSet(JsonUtility.FromJson<JSONTestSet>(json));
        }

        /// <summary>
        /// Load raw test set: two consecutive float arrays (input, then labels), each in the
        /// format read by <see cref="LoadFloatArray"/>.
        /// </summary>
        /// <param name="filename">file name</param>
        /// <returns>`TestSet`</returns>
        public static TestSet LoadRaw(string filename)
        {
            string fullpath = Path.Combine(Application.streamingAssetsPath, "TestSet", filename);

            using (BinaryReader file = Open(fullpath))
            {
                var rawTestSet = new RawTestSet();
                rawTestSet.input = LoadFloatArray(file);
                rawTestSet.labels = LoadFloatArray(file);
                return new TestSet(rawTestSet);
            }
        }

        /// <summary>
        /// Load image into a clamped <c>Texture2D</c>
        /// </summary>
        /// <param name="filename">file name</param>
        /// <returns>`Texture`</returns>
        /// <exception cref="InvalidDataException">the file could not be decoded as an image</exception>
        public static Texture LoadImage(string filename)
        {
            string fullpath = Path.Combine(Application.streamingAssetsPath, "TestSet", filename);

            var bytes = File.ReadAllBytes(fullpath);
            var tex = new Texture2D(2, 2);
            // LoadImage will auto-resize the texture dimensions; it reports failure through its
            // return value instead of throwing, so check it explicitly.
            if (!ImageConversion.LoadImage(tex, bytes, false))
                throw new InvalidDataException($"Failed to decode image '{fullpath}'");
            tex.wrapMode = TextureWrapMode.Clamp;
            return tex;
        }

        /// <summary>
        /// Load float array: an Int64 element count followed by that many 4-byte floats. The float
        /// bytes are copied verbatim (<c>Buffer.BlockCopy</c>), i.e. interpreted in host byte order.
        /// </summary>
        /// <param name="file">binary file reader</param>
        /// <returns>float array</returns>
        /// <exception cref="InvalidDataException">the stored element count is negative or too large</exception>
        /// <exception cref="EndOfStreamException">the stream ends before all floats are read</exception>
        public static float[] LoadFloatArray(BinaryReader file)
        {
            const int maxElementCount = int.MaxValue / sizeof(float);

            Int64 dataLength = file.ReadInt64();
            // @TODO: support larger than MaxInt32 data blocks
            if (dataLength < 0 || dataLength > maxElementCount)
                throw new InvalidDataException(
                    $"Invalid float array length {dataLength}; expected 0..{maxElementCount}");

            int byteCount = (int)dataLength * sizeof(float);
            byte[] bytes = file.ReadBytes(byteCount);
            // ReadBytes returns a shorter array at end of stream instead of throwing; without this
            // check a truncated file would silently produce trailing zeros.
            if (bytes.Length != byteCount)
                throw new EndOfStreamException(
                    $"Expected {byteCount} bytes of float data but only {bytes.Length} were available");

            float[] array = new float[dataLength];
            Buffer.BlockCopy(bytes, 0, array, 0, byteCount);
            return array;
        }

        /// <summary>
        /// Open file with binary reader (read-only)
        /// </summary>
        /// <param name="filename">file name</param>
        /// <returns>`BinaryReader`</returns>
        static BinaryReader Open(string filename)
        {
            return new BinaryReader(new FileStream(filename, FileMode.Open, FileAccess.Read));
        }
    }
} // namespace Unity.Barracuda