using System;
using System.Collections.Generic;
using UnityEngine;
using Unity.Jobs;
namespace Unity.Barracuda
{
/// <summary>
/// BLAS plugin interface, allows supplying a platform specific implementation of matrix multiplication
/// </summary>
public interface BLASPlugin
{
/// <summary>
/// Query if BLAS implementation is coming from platform's native library
/// </summary>
/// <returns>`true` if BLAS implementation is coming from platform's native library</returns>
bool IsNative();
/// <summary>
/// Query if current platform is supported by the BLAS plugin
/// </summary>
/// <returns>`true` if plugin supports current platform</returns>
bool IsCurrentPlatformSupported();
/// <summary>
/// Perform matrix multiplication C = A x B + C
/// </summary>
/// <param name="Ap">pointer to the matrix A</param>
/// <param name="AM">matrix A row count</param>
/// <param name="AN">matrix A column count</param>
/// <param name="Bp">pointer to the matrix B</param>
/// <param name="BM">matrix B row count</param>
/// <param name="BN">matrix B column count</param>
/// <param name="Cp">pointer to the matrix C</param>
/// <param name="CM">matrix C row count</param>
/// <param name="CN">matrix C column count</param>
/// <param name="bs">inner loop block size (if applicable) bs x bs</param>
/// <param name="transposeA">matrix A data is in transposed layout</param>
/// <param name="transposeB">matrix B data is in transposed layout</param>
unsafe void SGEMM(float* Ap, int AM, int AN,
float* Bp, int BM, int BN,
float* Cp, int CM, int CN, int bs,
bool transposeA = false, bool transposeB = false);
/// <summary>
/// Launches matrix multiplication C = A x B + C in async-manner
/// </summary>
/// <param name="dependsOn">input data dependency job handle</param>
/// <param name="Ap">pointer to the matrix A</param>
/// <param name="AM">matrix A row count</param>
/// <param name="AN">matrix A column count</param>
/// <param name="Bp">pointer to the matrix B</param>
/// <param name="BM">matrix B row count</param>
/// <param name="BN">matrix B column count</param>
/// <param name="Cp">pointer to the matrix C</param>
/// <param name="CM">matrix C row count</param>
/// <param name="CN">matrix C column count</param>
/// <param name="bs">inner loop block size (if applicable) bs x bs</param>
/// <param name="transposeA">matrix A data is in transposed layout</param>
/// <param name="transposeB">matrix B data is in transposed layout</param>
/// <returns>job handle</returns>
unsafe JobHandle ScheduleSGEMM(JobHandle dependsOn,
float* Ap, int AM, int AN,
float* Bp, int BM, int BN,
float* Cp, int CM, int CN, int bs,
bool transposeA = false, bool transposeB = false);
}
/// <summary>
/// Factory that picks a <see cref="BLASPlugin"/> implementation suitable for the current platform
/// </summary>
internal class BLASPluginFactory
{
    /// <summary>
    /// Create the highest-priority BLAS plugin that can be instantiated and reports the current platform as supported.
    /// Candidates are tried in this order: platform specific plugin (iOS or macOS, when running there),
    /// then `Unity.Barracuda.BurstBLAS`, then the `CSharpBLAS` fallback.
    /// </summary>
    /// <returns>the selected plugin instance, or `null` if no candidate could be created</returns>
    public static BLASPlugin CreateBLASPlugin()
    {
        // TODO make plugins discoverable via custom attributes

        // Candidates are stored as type names and resolved by reflection over the loaded assemblies.
        // Stack is LIFO: the entry pushed last is tried first, so the fallback goes in first.
        var candidates = new Stack<string>();
        candidates.Push(typeof(CSharpBLAS).FullName);
        candidates.Push("Unity.Barracuda.BurstBLAS");

        if (Application.platform == RuntimePlatform.IPhonePlayer)
            candidates.Push("Unity.Barracuda.iOSBLAS");
        else if (Application.platform == RuntimePlatform.OSXPlayer || Application.platform == RuntimePlatform.OSXEditor)
            candidates.Push("Unity.Barracuda.MacBLAS");

        while (candidates.Count > 0)
        {
            string typeName = candidates.Pop();

            foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
            {
                var type = assembly.GetType(typeName);
                if (type == null)
                    continue;

                try
                {
                    var plugin = Activator.CreateInstance(type) as BLASPlugin;

                    // Found working candidate - stop searching immediately instead of
                    // scanning (and possibly instantiating from) the remaining assemblies.
                    if (plugin != null && plugin.IsCurrentPlatformSupported())
                        return plugin;
                }
                catch (Exception e)
                {
                    D.LogWarning($"Failed to load {type} with exception {e}");

                    // Give up on this candidate and fall through to the next lower-priority one.
                    break;
                }
            }
        }

        return null;
    }
}
}