using System;
using System.Collections.Generic;
using System.IO;
using UnityEngine;
using UnityEngine.Assertions;
using System.IO.Compression;
namespace Unity.Barracuda {
/// <summary>
/// Test set loading utility. Wraps either a <see cref="RawTestSet"/> (exactly one unnamed input
/// and one output) or a <see cref="JSONTestSet"/> (any number of named inputs and outputs)
/// behind a single query API.
/// </summary>
public class TestSet
{
    // Suffix removed from tensor names before they are reported to callers
    // (presumably a TensorFlow-style output-slot index for slot 0 — confirm against the exporter).
    private const string k_NameSlotSuffix = ":0";

    private RawTestSet rawTestSet;
    private JSONTestSet jsonTestSet;

    /// <summary>
    /// Create with raw test set
    /// </summary>
    /// <param name="rawTestSet">raw test set</param>
    public TestSet(RawTestSet rawTestSet)
    {
        this.rawTestSet = rawTestSet;
    }

    /// <summary>
    /// Create with JSON test set
    /// </summary>
    /// <param name="jsonTestSet">JSON test set</param>
    public TestSet(JSONTestSet jsonTestSet)
    {
        this.jsonTestSet = jsonTestSet;
    }

    /// <summary>
    /// Create a `TestSet` with neither a raw nor a JSON test set attached.
    /// </summary>
    /// <remarks>
    /// The query methods fall through to the JSON test set when no raw test set is present,
    /// so calling them on an instance created this way dereferences a null JSON test set.
    /// </remarks>
    public TestSet()
    {
    }

    /// <summary>
    /// Check if test set supports named tensors
    /// </summary>
    /// <returns>`true` unless this wraps a raw test set (raw tensors are unnamed)</returns>
    public bool SupportsNames()
    {
        return rawTestSet == null;
    }

    /// <summary>
    /// Get output tensor count
    /// </summary>
    /// <returns>1 for a raw test set, otherwise the number of JSON outputs</returns>
    public int GetOutputCount()
    {
        if (rawTestSet != null)
            return 1;
        return jsonTestSet.outputs.Length;
    }

    /// <summary>
    /// Get output tensor data
    /// </summary>
    /// <param name="idx">tensor index (ignored for a raw test set)</param>
    /// <returns>tensor data</returns>
    public float[] GetOutputData(int idx = 0)
    {
        if (rawTestSet != null)
            return rawTestSet.labels;
        return jsonTestSet.outputs[idx].data;
    }

    /// <summary>
    /// Get output tensor name
    /// </summary>
    /// <param name="idx">tensor index</param>
    /// <returns>tensor name with a trailing ":0" removed, or `null` for a raw test set</returns>
    public string GetOutputName(int idx = 0)
    {
        if (rawTestSet != null)
            return null;
        return StripSlotSuffix(jsonTestSet.outputs[idx].name);
    }

    /// <summary>
    /// Get input tensor count
    /// </summary>
    /// <returns>1 for a raw test set, otherwise the number of JSON inputs</returns>
    public int GetInputCount()
    {
        if (rawTestSet != null)
            return 1;
        return jsonTestSet.inputs.Length;
    }

    /// <summary>
    /// Get input tensor name
    /// </summary>
    /// <param name="idx">input tensor index</param>
    /// <returns>tensor name with a trailing ":0" removed, or "" for a raw test set</returns>
    public string GetInputName(int idx = 0)
    {
        if (rawTestSet != null)
            return "";
        return StripSlotSuffix(jsonTestSet.inputs[idx].name);
    }

    /// <summary>
    /// Get input tensor data
    /// </summary>
    /// <param name="idx">input tensor index (ignored for a raw test set)</param>
    /// <returns>tensor data</returns>
    public float[] GetInputData(int idx = 0)
    {
        if (rawTestSet != null)
            return rawTestSet.input;
        return jsonTestSet.inputs[idx].data;
    }

    /// <summary>
    /// Get input shape
    /// </summary>
    /// <param name="idx">input tensor index</param>
    /// <returns>input shape; (1, input length) for a raw test set</returns>
    public TensorShape GetInputShape(int idx = 0)
    {
        if (rawTestSet != null)
            return new TensorShape(1, rawTestSet.input.Length);
        return ToTensorShape(jsonTestSet.inputs[idx].shape);
    }

    /// <summary>
    /// Get output tensor shape
    /// </summary>
    /// <param name="idx">output tensor index</param>
    /// <returns>tensor shape; (1, labels length) for a raw test set</returns>
    public TensorShape GetOutputShape(int idx = 0)
    {
        if (rawTestSet != null)
            return new TensorShape(1, rawTestSet.labels.Length);
        return ToTensorShape(jsonTestSet.outputs[idx].shape);
    }

    /// <summary>
    /// Get inputs as `Tensor` dictionary
    /// </summary>
    /// <param name="inputs">dictionary to store results; a new one is created when `null`</param>
    /// <param name="batchCount">max batch count; negative means "all remaining batches"</param>
    /// <param name="fromBatch">start from batch</param>
    /// <returns>dictionary with input tensors keyed by input name</returns>
    /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
    public Dictionary<string, Tensor> GetInputsAsTensorDictionary(Dictionary<string, Tensor> inputs = null, int batchCount = -1, int fromBatch = 0)
    {
        if (rawTestSet != null)
            throw new NotSupportedException("GetInputsAsTensorDictionary is not supported for RAW test suites");
        if (inputs == null)
            inputs = new Dictionary<string, Tensor>();
        for (var i = 0; i < GetInputCount(); i++)
            inputs[GetInputName(i)] = GetInputAsTensor(i, batchCount, fromBatch);
        return inputs;
    }

    /// <summary>
    /// Get outputs as `Tensor` dictionary
    /// </summary>
    /// <param name="outputs">dictionary to store results; a new one is created when `null`</param>
    /// <param name="batchCount">max batch count; negative means "all remaining batches"</param>
    /// <param name="fromBatch">start from batch</param>
    /// <returns>dictionary with output tensors keyed by output name</returns>
    /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
    public Dictionary<string, Tensor> GetOutputsAsTensorDictionary(Dictionary<string, Tensor> outputs = null, int batchCount = -1, int fromBatch = 0)
    {
        if (rawTestSet != null)
            throw new NotSupportedException("GetOutputsAsTensorDictionary is not supported for RAW test suites");
        if (outputs == null)
            outputs = new Dictionary<string, Tensor>();
        for (var i = 0; i < GetOutputCount(); i++)
            outputs[GetOutputName(i)] = GetOutputAsTensor(i, batchCount, fromBatch);
        return outputs;
    }

    /// <summary>
    /// Get input as `Tensor`. The data is copied into a new buffer; if more batches are requested
    /// than the test set holds, the missing batches are zero-filled.
    /// </summary>
    /// <param name="idx">input index</param>
    /// <param name="batchCount">max batch count; negative means "all remaining batches"</param>
    /// <param name="fromBatch">start from batch (clamped to the available range)</param>
    /// <returns>`Tensor`</returns>
    /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
    public Tensor GetInputAsTensor(int idx = 0, int batchCount = -1, int fromBatch = 0)
    {
        if (rawTestSet != null)
            throw new NotSupportedException("GetInputAsTensor is not supported for RAW test suites");

        TensorShape shape = GetInputShape(idx);
        Assert.IsTrue(shape.sequenceLength == 1 && shape.numberOfDirections == 1);

        var array = GetInputData(idx);
        var maxBatchCount = array.Length / shape.flatWidth;
        fromBatch = ClampFromBatch(fromBatch, maxBatchCount);
        var availableBatches = maxBatchCount - fromBatch;
        if (batchCount < 0)
            batchCount = availableBatches;

        // Pad data with 0s if the test set doesn't have enough batches:
        // the buffer is sized for batchCount, but only the available batches are copied in.
        var tensorShape = WithBatch(shape, batchCount);
        var sourceStart = fromBatch * tensorShape.flatWidth;
        var copyLength = Math.Min(batchCount, availableBatches) * tensorShape.flatWidth;
        var dataToUpload = new float[tensorShape.length];
        Array.Copy(array, sourceStart, dataToUpload, 0, copyLength);

        var data = new ArrayTensorData(tensorShape.length);
        data.Upload(dataToUpload, tensorShape, 0);
        var res = new Tensor(tensorShape, data);
        // GetInputName already strips the suffix; stripping again would make the tensor name
        // disagree with the dictionary key used by GetInputsAsTensorDictionary.
        res.name = GetInputName(idx);
        return res;
    }

    /// <summary>
    /// Get output as `Tensor`. The tensor shares the test set's managed array (no copy), so the
    /// batch count is limited to the batches actually stored.
    /// </summary>
    /// <param name="idx">output index</param>
    /// <param name="batchCount">max batch count; negative means "all remaining batches"</param>
    /// <param name="fromBatch">start from batch (clamped to the available range)</param>
    /// <returns>`Tensor`</returns>
    /// <exception cref="NotSupportedException">thrown if called on raw test set (only JSON test set is supported)</exception>
    public Tensor GetOutputAsTensor(int idx = 0, int batchCount = -1, int fromBatch = 0)
    {
        if (rawTestSet != null)
            throw new NotSupportedException("GetOutputAsTensor is not supported for RAW test suites");

        TensorShape shape = GetOutputShape(idx);
        Assert.IsTrue(shape.sequenceLength == 1 && shape.numberOfDirections == 1);

        var barracudaArray = new BarracudaArrayFromManagedArray(GetOutputData(idx));
        var maxBatchCount = barracudaArray.Length / shape.flatWidth;
        fromBatch = ClampFromBatch(fromBatch, maxBatchCount);
        var availableBatches = maxBatchCount - fromBatch;
        batchCount = batchCount < 0 ? availableBatches : Math.Min(batchCount, availableBatches);

        var tensorShape = WithBatch(shape, batchCount);
        var offset = fromBatch * tensorShape.flatWidth;
        var res = new Tensor(tensorShape, new SharedArrayTensorData(barracudaArray, tensorShape, offset));
        // GetOutputName already strips the suffix (see GetInputAsTensor for why it is not re-stripped).
        res.name = GetOutputName(idx);
        return res;
    }

    // Removes one trailing ":0" from a tensor name using an ordinal (culture-independent) comparison.
    private static string StripSlotSuffix(string name)
    {
        if (name != null && name.EndsWith(k_NameSlotSuffix, StringComparison.Ordinal))
            return name.Substring(0, name.Length - k_NameSlotSuffix.Length);
        return name;
    }

    // Builds the 8D TensorShape from a JSON shape, in the order the TensorShape constructor expects.
    private static TensorShape ToTensorShape(JSONTensorShape s)
    {
        return new TensorShape(s.sequenceLength, s.numberOfDirections, s.batch, s.extraDimension,
            s.depth, s.height, s.width, s.channels);
    }

    // Returns a copy of `shape` with its batch dimension replaced by `batch`.
    private static TensorShape WithBatch(TensorShape shape, int batch)
    {
        var dims = shape.ToArray();
        dims[TensorShape.DataBatch] = batch;
        return new TensorShape(dims);
    }

    // Clamps the starting batch into [0, maxBatchCount - 1]; yields 0 when there are no full batches,
    // so copy start indices / shared-array offsets never go negative.
    private static int ClampFromBatch(int fromBatch, int maxBatchCount)
    {
        return Math.Max(0, Math.Min(fromBatch, maxBatchCount - 1));
    }
}
/// <summary>
/// Raw test structure: a single unnamed input vector and its expected output vector,
/// as filled by <see cref="TestSetLoader.LoadRaw"/> from a binary ".raw" file.
/// </summary>
public class RawTestSet
{
/// <summary>
/// Input data. <see cref="TestSet"/> reports it as one tensor of shape (1, input.Length).
/// </summary>
public float[] input;
/// <summary>
/// Output data (expected labels). <see cref="TestSet"/> reports it as one tensor of shape (1, labels.Length).
/// </summary>
public float[] labels;
}
/// <summary>
/// JSON test structure: named input tensors and expected output tensors,
/// deserialized with <c>JsonUtility.FromJson</c> by <see cref="TestSetLoader.LoadJSON"/>.
/// </summary>
/// <remarks>
/// JsonUtility maps JSON keys onto these public fields by name, so renaming a field
/// breaks loading of existing test-set files.
/// </remarks>
[Serializable]
public class JSONTestSet
{
/// <summary>
/// Inputs, addressed by index through <see cref="TestSet"/> (e.g. <c>GetInputData(idx)</c>).
/// </summary>
public JSONTensor[] inputs;
/// <summary>
/// Expected outputs, addressed by index through <see cref="TestSet"/> (e.g. <c>GetOutputData(idx)</c>).
/// </summary>
public JSONTensor[] outputs;
}
/// <summary>
/// JSON tensor shape: eight dimension sizes. <see cref="TestSet"/> passes them to the
/// <c>TensorShape</c> constructor in declaration order
/// (sequenceLength, numberOfDirections, batch, extraDimension, depth, height, width, channels).
/// </summary>
/// <remarks>
/// <c>TestSet.GetInputAsTensor</c> / <c>GetOutputAsTensor</c> assert that
/// <see cref="sequenceLength"/> and <see cref="numberOfDirections"/> are both 1.
/// Field names are JSON keys and must not be renamed.
/// </remarks>
[Serializable]
public class JSONTensorShape
{
/// <summary>
/// Size of the sequence-length dimension
/// </summary>
public int sequenceLength;
/// <summary>
/// Size of the number-of-directions dimension
/// </summary>
public int numberOfDirections;
/// <summary>
/// Size of the batch dimension; <c>TestSet</c> overrides it when slicing batches into a tensor
/// </summary>
public int batch;
/// <summary>
/// Size of the extra dimension
/// </summary>
public int extraDimension;
/// <summary>
/// Size of the depth dimension
/// </summary>
public int depth;
/// <summary>
/// Size of the height dimension
/// </summary>
public int height;
/// <summary>
/// Size of the width dimension
/// </summary>
public int width;
/// <summary>
/// Size of the channels dimension
/// </summary>
public int channels;
}
/// <summary>
/// JSON tensor: one named tensor of a <see cref="JSONTestSet"/>.
/// Field names are JSON keys and must not be renamed.
/// </summary>
[Serializable]
public class JSONTensor
{
/// <summary>
/// Name. <see cref="TestSet"/> strips a trailing ":0" before reporting it.
/// </summary>
public string name;
/// <summary>
/// Shape (eight dimension sizes)
/// </summary>
public JSONTensorShape shape;
/// <summary>
/// Tensor type string as stored in the JSON; not read by <see cref="TestSet"/>.
/// </summary>
public string type;
/// <summary>
/// Tensor data as a flat array. <see cref="TestSet"/> treats it as consecutive batches of
/// <c>flatWidth</c> elements each.
/// </summary>
public float[] data;
}
/// <summary>
/// Test set loader. Reads test sets from <c>StreamingAssets/TestSet</c> (".raw", ".gz", ".json")
/// or, for other names, from a <c>TextAsset</c> under <c>Resources/TestSet</c>.
/// </summary>
public class TestSetLoader
{
    // Sub-folder of StreamingAssets that holds test-set files.
    private const string k_TestSetFolder = "TestSet";

    /// <summary>
    /// Load test set from file, choosing the format from the (case-insensitive) extension:
    /// ".raw" → <see cref="LoadRaw"/>, ".gz" → <see cref="LoadGZ"/>, anything else → <see cref="LoadJSON"/>.
    /// </summary>
    /// <param name="filename">file name</param>
    /// <returns>`TestSet`</returns>
    public static TestSet Load(string filename)
    {
        if (filename.EndsWith(".raw", StringComparison.OrdinalIgnoreCase))
            return LoadRaw(filename);
        if (filename.EndsWith(".gz", StringComparison.OrdinalIgnoreCase))
            return LoadGZ(filename);
        return LoadJSON(filename);
    }

    /// <summary>
    /// Load GZ: decompresses <c>StreamingAssets/TestSet/filename</c> next to itself (the last
    /// three characters, i.e. ".gz", are dropped from the name) and loads the result with
    /// <see cref="LoadJSON"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the decompressed file is only read back from StreamingAssets if its name ends
    /// with ".json" (e.g. "set.json.gz"); otherwise LoadJSON looks it up in Resources instead.
    /// Writing into StreamingAssets also assumes that folder is writable on the target platform.
    /// </remarks>
    /// <param name="filename">file name</param>
    /// <returns>`TestSet`</returns>
    public static TestSet LoadGZ(string filename)
    {
        var jsonFileName = filename.Substring(0, filename.Length - 3);
        var archivePath = StreamingAssetPath(filename);
        var decompressedPath = archivePath.Substring(0, archivePath.Length - 3);

        using (FileStream archiveStream = File.OpenRead(archivePath))
        using (FileStream decompressedStream = File.Create(decompressedPath))
        using (var decompressionStream = new GZipStream(archiveStream, CompressionMode.Decompress))
        {
            decompressionStream.CopyTo(decompressedStream);
        }

        return LoadJSON(jsonFileName);
    }

    /// <summary>
    /// Load JSON. Names ending in ".json" are read from <c>StreamingAssets/TestSet</c>; any other
    /// name is loaded as the <c>TextAsset</c> <c>Resources/TestSet/filename</c>.
    /// </summary>
    /// <param name="filename">file name</param>
    /// <returns>`TestSet`</returns>
    /// <exception cref="FileNotFoundException">thrown if the Resources text asset does not exist</exception>
    public static TestSet LoadJSON(string filename)
    {
        string json;
        if (filename.EndsWith(".json", StringComparison.Ordinal))
        {
            json = File.ReadAllText(StreamingAssetPath(filename));
        }
        else
        {
            var resourcePath = $"TestSet/{filename}";
            var asset = Resources.Load<TextAsset>(resourcePath);
            // Unity's overloaded == also treats destroyed objects as null.
            if (asset == null)
                throw new FileNotFoundException($"Test set resource '{resourcePath}' was not found", filename);
            json = asset.text;
        }
        return new TestSet(JsonUtility.FromJson<JSONTestSet>(json));
    }

    /// <summary>
    /// Load raw test set: a binary file holding two float arrays (input, then labels),
    /// each in the format read by <see cref="LoadFloatArray"/>.
    /// </summary>
    /// <param name="filename">file name</param>
    /// <returns>`TestSet`</returns>
    public static TestSet LoadRaw(string filename)
    {
        using (BinaryReader file = Open(StreamingAssetPath(filename)))
        {
            var rawTestSet = new RawTestSet();
            rawTestSet.input = LoadFloatArray(file);
            rawTestSet.labels = LoadFloatArray(file);
            return new TestSet(rawTestSet);
        }
    }

    /// <summary>
    /// Load image from <c>StreamingAssets/TestSet</c> into a clamped <c>Texture2D</c>.
    /// </summary>
    /// <param name="filename">file name</param>
    /// <returns>`Texture`</returns>
    public static Texture LoadImage(string filename)
    {
        var bytes = File.ReadAllBytes(StreamingAssetPath(filename));
        var tex = new Texture2D(2, 2);
        ImageConversion.LoadImage(tex, bytes, false); // LoadImage will auto-resize the texture dimensions
        tex.wrapMode = TextureWrapMode.Clamp;
        return tex;
    }

    /// <summary>
    /// Load float array: an Int64 element count followed by that many raw 4-byte floats
    /// (copied byte-for-byte, i.e. in the machine's native byte order).
    /// </summary>
    /// <param name="file">binary file reader</param>
    /// <returns>float array</returns>
    /// <exception cref="EndOfStreamException">thrown if the stream ends before all floats are read</exception>
    /// <exception cref="OverflowException">thrown if the byte count does not fit in an Int32</exception>
    public static float[] LoadFloatArray(BinaryReader file)
    {
        Int64 dataLength = file.ReadInt64();
        float[] array = new float[dataLength];
        int byteCount = Convert.ToInt32(dataLength * sizeof(float)); // @TODO: support larger than MaxInt32 data blocks
        byte[] bytes = file.ReadBytes(byteCount);
        // ReadBytes returns a short array at end of stream; a truncated file must not silently
        // produce a zero-padded tensor.
        if (bytes.Length != byteCount)
            throw new EndOfStreamException($"Expected {byteCount} bytes of float data but only {bytes.Length} were available");
        Buffer.BlockCopy(bytes, 0, array, 0, bytes.Length);
        return array;
    }

    /// <summary>
    /// Open file with binary reader (read-only); the reader owns and disposes the stream.
    /// </summary>
    /// <param name="filename">file name</param>
    /// <returns>`BinaryReader`</returns>
    static BinaryReader Open(string filename)
    {
        return new BinaryReader(new FileStream(filename, FileMode.Open, FileAccess.Read));
    }

    // Full path of a file inside StreamingAssets/TestSet.
    private static string StreamingAssetPath(string filename)
    {
        return Path.Combine(Application.streamingAssetsPath, k_TestSetFolder, filename);
    }
}
} // namespace Unity.Barracuda