Files
unity-application/Packages/com.unity.barracuda/Runtime/Core/Backends/BarracudaBurstCPU.Jobs.cs
2023-03-18 19:53:17 +00:00

1647 lines
66 KiB
C#

using UnityEngine;
using System;
using System.Collections.Generic;
using System.Threading;
using Unity.Collections;
using Unity.Collections.LowLevel.Unsafe;
using Unity.Burst;
using Unity.Jobs;
using Unity.Jobs.LowLevel.Unsafe;
using Unity.Mathematics;
[assembly: BurstCompile(OptimizeFor = OptimizeFor.FastCompilation)]
namespace Unity.Barracuda {
// BarracudaBurstCPU.Core.cs -- definition of class BurstCPUOps, Pin(), BurstTensorData
// BarracudaBurstCPU.Ops.cs -- impl. IOps, job schedulers
// BarracudaBurstCPU.Jobs.cs -- impl. jobs
public partial class BurstCPUOps
{
// Thread that ran this type's static initializer — assumed to be Unity's main
// thread; presumably used for main-thread checks elsewhere (TODO confirm caller).
internal static readonly Thread MainThread = Thread.CurrentThread;
#region Job resources declaration
// Untyped read-only tensor memory handed to a Burst job, plus typed views
// over the same pointer.
internal unsafe struct ReadOnlyMemResource
{
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public void* ptr;
    public float* ptrfloat => (float*)ptr;
    public half* ptrhalf => (half*)ptr;
}
// Untyped read-write tensor memory handed to a Burst job, plus typed views
// over the same pointer.
internal unsafe struct ReadWriteMemResource
{
    [NoAlias][NativeDisableUnsafePtrRestriction] public void* ptr;
    public float* ptrfloat => (float*)ptr;
    public half* ptrhalf => (half*)ptr;
}
// Resource-binding interfaces implemented by the job structs below. The job
// schedulers (see BarracudaBurstCPU.Ops.cs, per the file header) presumably use
// these to assign tensor pointers generically. Naming convention:
// X = input, S = scale, B = bias, O = output.

// Job writes a single output resource.
internal interface IJobResourceDeclarationO
{
    ReadWriteMemResource O { get; set; }
}
// Job reads X and writes O.
internal interface IJobResourceDeclarationXO
{
    ReadOnlyMemResource X { get; set; }
    ReadWriteMemResource O { get; set; }
}
// Job reads X and B, writes O.
internal interface IJobResourceDeclarationXBO
{
    ReadOnlyMemResource X { get; set; }
    ReadOnlyMemResource B { get; set; }
    ReadWriteMemResource O { get; set; }
}
// Job reads X, S and B, writes O.
internal interface IJobResourceDeclarationXSBO
{
    ReadOnlyMemResource X { get; set; }
    ReadOnlyMemResource S { get; set; }
    ReadOnlyMemResource B { get; set; }
    ReadWriteMemResource O { get; set; }
}
#endregion
#region Job inner data declaration
// Scalar parameters for the HardSigmoid activation job.
internal partial struct HardSigmoidJobHelper
{
    [ReadOnly] public float alpha, beta;
}
// Clamp bounds for the Clip job.
internal partial struct ClipJobHelper
{
    [ReadOnly] public float min, max;
}
// Exponent for the Pow job.
internal partial struct PowJobHelper
{
    [ReadOnly] public float alpha;
}
// Scale parameter for the Elu activation job.
internal partial struct EluJobHelper
{
    [ReadOnly] public float alpha;
}
// Scale parameters for the Selu activation job.
internal partial struct SeluJobHelper
{
    [ReadOnly] public float alpha, gamma;
}
// Parameters for the PRelu job.
internal partial struct PReluJobHelper
{
    [ReadOnly] public int inOutChannels;
    [ReadOnly] public int isGammaAVector; //1 if true, 0 if false
}
// Parameters for the LeakyRelu job. Assigning `alpha` precomputes f1/f2 —
// presumably so the kernel (not visible here) can evaluate the activation as
// f1*x + f2*|x|, per the Theano reference below; TODO confirm against kernel.
internal partial struct LeakyReluJobHelper
{
    // from Theano impl
    // https://github.com/Theano/theano/blob/d395439aec5a6ddde8ef5c266fd976412a5c5695/theano/tensor/nnet/nnet.py#L2209-L2251
    [ReadOnly] public float f1, f2, alpha_;
    // Setter keeps f1/f2 in sync with alpha; getter returns the raw value.
    public float alpha { get { return alpha_; } set {
        alpha_ = value;
        f1 = 0.5f * (1f + alpha_);
        f2 = 0.5f * (1f - alpha_);
    } }
}
// Element count for the plain copy job.
internal partial struct CopyJobHelper
{
    [ReadOnly] public int length;
}
// Parameters for the strided copy job: `count` runs of `length` elements with
// separate source (X) and destination (O) strides between runs.
internal partial struct CopyStrideJobHelper
{
    [ReadOnly] public int XStride;
    [ReadOnly] public int OStride;
    [ReadOnly] public int count;
    [ReadOnly] public int length;
}
// Per-axis start offsets and strides (8 axes: S,R,N,T,D,H,W,C) used to map
// output coordinates in shapeO back into shapeX for the Slice job.
internal partial struct GenericSliceJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int strideS, strideR, strideN, strideT;
    [ReadOnly] public int strideD, strideH, strideW, strideC;
    [ReadOnly] public int startS, startR, startN, startT;
    [ReadOnly] public int startD, startH, startW, startC;
}
// Same layout as GenericSliceJobHelper, for the StridedSlice job.
internal partial struct GenericStridedSliceJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int strideS, strideR, strideN, strideT;
    [ReadOnly] public int strideD, strideH, strideW, strideC;
    [ReadOnly] public int startS, startR, startN, startT;
    [ReadOnly] public int startD, startH, startW, startC;
}
// Pad/crop extents for the Border2D job; Beta is the constant fill value.
internal partial struct Border2DJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int PadWidth;
    [ReadOnly] public int PadHeight;
    [ReadOnly] public int PadChannels;
    [ReadOnly] public int CroppedWidth;
    [ReadOnly] public int CroppedHeight;
    [ReadOnly] public int CroppedChannels;
    [ReadOnly] public float Beta;
}
// Axis permutation (up to 8 axes) for the Transpose job.
internal unsafe partial struct TransposeJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public fixed int permutations[8];
}
// Pad extents for 2D edge (replicate) padding.
internal partial struct Pad2DEdgeJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int PadWidth;
    [ReadOnly] public int PadHeight;
    [ReadOnly] public int PadChannels;
}
// Pad extents for 2D reflect padding.
internal partial struct Pad2DReflectJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int PadWidth;
    [ReadOnly] public int PadHeight;
    [ReadOnly] public int PadChannels;
}
// Pad extents for 2D symmetric padding.
internal partial struct Pad2DSymmetricJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int PadWidth;
    [ReadOnly] public int PadHeight;
    [ReadOnly] public int PadChannels;
}
// Input/output shapes for the Tile job.
internal partial struct TileJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
}
// Gather job: axis along which indices select slices of X.
internal partial struct GatherJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int axis;
}
// OneHot job: number of classes (depth), input rank, and the values written
// for the hot/non-hot positions.
internal partial struct OneHotJobHelper
{
    [ReadOnly] public TensorShape shapeO;
    [ReadOnly] public TensorShape shapeX;
    [ReadOnly] public int depth;
    [ReadOnly] public int inputRank;
    [ReadOnly] public float onValue;
    [ReadOnly] public float offValue;
}
// RandomNormal job: RNG state plus distribution parameters.
internal partial struct RandomNormalJobHelper
{
    public Unity.Mathematics.Random rng;
    public float mean;
    public float scale;
}
// RandomUniform job: RNG state plus range parameters (naming mirrors
// RandomNormalJobHelper).
internal partial struct RandomUniformJobHelper
{
    public Unity.Mathematics.Random rng;
    public float mean;
    public float scale;
}
// Test-only job parameters (XO resource layout).
internal partial struct TestXOJobHelper
{
    public int offset;
    public float bias;
}
// Test-only job parameters (XBO resource layout).
internal partial struct TestXBOJobHelper
{
    public int offset;
}
// Per-channel scale/bias broadcast job: channel count and scalar alpha.
internal partial struct VectorBroadcastScaleBiasJobHelper
{
    [ReadOnly] public int inOutChannels;
    [ReadOnly] public float alpha;
}
// Geometry for the depthwise 2D convolution job: stride/pad, input layout,
// kernel layout and output layout strides.
internal partial struct DepthwiseConv2DJobHelper
{
    [ReadOnly] public int strideX, strideY, padX, padY;
    [ReadOnly] public int inHeight, inWidth, inChannels, inStrideN, inStrideH, inStrideW;
    [ReadOnly] public int kernelCount, kernelHeight, kernelWidth, kernelStrideH, kernelStrideW;
    [ReadOnly] public int outBatch, outWidth, outStrideN, outStrideH, outStrideW;
}
// Dense (matmul) job: A/B/S matrix dimensions plus flattened 3D dispatch size.
internal partial struct Dense3JobHelper
{
    public int AM, AN;
    public int BM, BN;
    public int SM, SN;
    public int dispatchThreadX, dispatchThreadY, dispatchThreadZ;
}
// Reduction jobs: reduceDim is the length of the reduced axis, offsetReduce
// its layout offset/stride within the flattened tensor.
internal partial struct ReduceMaxJobHelper
{
    [ReadOnly] public int offsetReduce;
    [ReadOnly] public int reduceDim;
}
internal partial struct ReduceSumJobHelper
{
    [ReadOnly] public int offsetReduce;
    [ReadOnly] public int reduceDim;
}
internal partial struct ReduceMeanJobHelper
{
    [ReadOnly] public int offsetReduce;
    [ReadOnly] public int reduceDim;
}
// Softmax pipeline stages share the same reduction layout parameters.
internal partial struct ExpBiasReduceJobHelper
{
    [ReadOnly] public int offsetReduce;
    [ReadOnly] public int reduceDim;
}
internal partial struct SoftmaxEndJobHelper
{
    [ReadOnly] public int offsetReduce;
    [ReadOnly] public int reduceDim;
}
internal partial struct LogSoftmaxEndJobHelper
{
    [ReadOnly] public int offsetReduce;
    [ReadOnly] public int reduceDim;
}
// Geometry for the 2D max-pooling job.
internal partial struct MaxPool2DJobHelper
{
    [ReadOnly] public int strideX, strideY, padX, padY;
    [ReadOnly] public int kernelHeight, kernelWidth;
    [ReadOnly] public int inHeight, inWidth, inChannels, inStrideN, inStrideH, inStrideW;
    [ReadOnly] public int outBatch, outWidth, outStrideN, outStrideH, outStrideW;
}
// Geometry for the 2D average-pooling job.
internal partial struct AvgPool2DJobHelper
{
    [ReadOnly] public int strideX, strideY, padX, padY;
    [ReadOnly] public int kernelHeight, kernelWidth;
    [ReadOnly] public int inHeight, inWidth, inChannels, inStrideN, inStrideH, inStrideW;
    [ReadOnly] public int outBatch, outWidth, outStrideN, outStrideH, outStrideW;
}
#endregion
// Allocate a cache-line-aligned float scratch block of blockSizeM x blockSizeN.
// Allocator.Temp is the fastest allocator, but can only be used within jobs;
// no explicit deallocation is needed, see
// https://docs.unity3d.com/Packages/com.unity.collections@1.0/manual/allocation.html#allocatortemp
static unsafe float* AllocBlock(int blockSizeM, int blockSizeN)
{
    int byteCount = blockSizeM * blockSizeN * sizeof(float);
    return (float*)UnsafeUtility.Malloc(byteCount, JobsUtility.CacheLineSize, Allocator.Temp);
}
// Allocate a cache-line-aligned half-precision scratch block of blockSizeM x blockSizeN.
// Allocator.Temp is the fastest allocator, but can only be used within jobs;
// no explicit deallocation is needed, see
// https://docs.unity3d.com/Packages/com.unity.collections@1.0/manual/allocation.html#allocatortemp
static unsafe half* AllocBlockHalf(int blockSizeM, int blockSizeN)
{
    int byteCount = blockSizeM * blockSizeN * sizeof(half);
    return (half*)UnsafeUtility.Malloc(byteCount, JobsUtility.CacheLineSize, Allocator.Temp);
}
// Intentionally a no-op: blocks come from Allocator.Temp, which is bulk-freed
// automatically at the end of the frame/job. Kept so call sites stay symmetric
// with AllocBlock and would keep working if the allocator ever changed.
static unsafe void FreeBlock(void* ptr)
{
    // We are using Allocator.Temp, so there is no explicit need to deallocate
    // if (ptr != null)
    //     UnsafeUtility.Free(ptr, Allocator.Temp);
}
// Copy a packed scratch tile back into the matrix, row by row, clamped to the
// matrix bounds. Used by the matmul jobs to write remainder tiles back to C
// (MatrixUtils.CopyFloatArray call order matches the other call sites here).
static unsafe void CopyBlock(float* blockOut, float* matrixIn, int row, int M, int col, int N, int blockSizeM, int blockSizeN)
{
    var lastRow = Math.Min(row + blockSizeM, M);
    var columns = Math.Min(col + blockSizeN, N) - col;
    for (var r = row; r < lastRow; r++)
    {
        MatrixUtils.CopyFloatArray(blockOut + (r - row) * blockSizeN, matrixIn + r * N + col, columns);
    }
}
// Copy a tile of matrixIn (starting at row/col, clamped to M x N) into a
// zero-padded blockSizeM x blockSizeN scratch block; optionally transposing.
// Returns the row stride of the packed block.
static unsafe int CopyBlockWithPadding(float* matrixIn, int row, int M, int col, int N, float* blockOut, int blockSizeM, int blockSizeN, bool transpose = false)
{
    // Zero-fill first so rows/columns beyond the matrix edge stay padded.
    MatrixUtils.ClearFloatArray(blockOut, 0, blockSizeM * blockSizeN);
    var stride = blockSizeN;
    var lastRow = Math.Min(row + blockSizeM, M);
    var columns = Math.Min(col + blockSizeN, N) - col;
    // @TODO: measure which one is better - sequential access over matrix memory or blockOut cache
    if (transpose)
    {
        // Sequential access over matrixIn, strided writes into blockOut.
        for (var j = 0; j < columns; ++j)
        {
            for (var r = row; r < lastRow; r++)
                blockOut[(r - row) * stride + j] = matrixIn[r + (col + j) * M];
        }
    }
    else
    {
        // Row-by-row bulk copy.
        for (var r = row; r < lastRow; r++)
            MatrixUtils.CopyFloatArray(matrixIn + r * N + col, blockOut + (r - row) * stride, columns);
    }
    return stride;
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Blocked single-precision GEMM over row-major matrices, parallelized over
// tiles of the longer of the two output axes. Block sizes are resolved at
// schedule time when left as 0. Remainder tiles and transposed inputs are
// staged through zero-padded scratch blocks.
// Convention: M x N matrices (other areas in our code may be N x M)
internal unsafe struct MatrixMultiplyJob : IJobParallelFor
{
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* A;
    public int AM, AN;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* B;
    public int BM, BN;
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* C;
    public int CM, CN;
    public bool transposeA;
    public bool transposeB;
    public int blockSizeM;
    public int blockSizeN;
    public int blockSizeK;
    public JobHandle Schedule(JobHandle dependsOn)
    {
        return Schedule(blocksBatchCount:1, dependsOn);
    }
    // Resolves transposed dimensions and block sizes, then schedules one work
    // item per block along the longer output axis (rows of A or columns of B).
    public JobHandle Schedule(int blocksBatchCount, JobHandle dependsOn)
    {
        if (transposeA)
        {
            int tmp = AM; AM = AN; AN = tmp;
        }
        if (transposeB)
        {
            int tmp = BM; BM = BN; BN = tmp;
        }
        // TODO: Determine optimal kernel / block sizes for mobile/console; This code path is currently not used
        // in production and instead MatrixMultiplyLegacyJob; However, this kernel size seemed to work best with
        // mobile; An alternative is have codegen generate the whole job + kernel, so we can switch dynamically
        // at runtime.
#if UNITY_ANDROID || UNITY_IOS || UNITY_WSA || UNITY_PS4 || UNITY_PS5 || UNITY_XBOXONE
        if (blockSizeM == 0 || blockSizeN == 0 || blockSizeK == 0)
        {
            blockSizeM = 64;
            blockSizeN = 64;
            blockSizeK = 16;
        }
#else
        if (blockSizeM == 0 || blockSizeN == 0 || blockSizeK == 0)
        {
            // Profiling across a range of matrices for best block size revealed:
            // (32, 384, 16) was the best common block size for matrices <= 576
            // (32, 768, 32) for matrices > 576 and <= 1152
            // (64, 96, 32) for matrices > 1200
            int maxM = 32;
            int maxN = 384;
            int maxK = 16;
            if (AM > 1200)
            {
                maxM = 64;
                maxN = 96;
                maxK = 32;
            }
            else if (AM > 576)
            {
                maxM = 32;
                maxN = 768;
                maxK = 32;
            }
            blockSizeM = Mathf.Min(AM, maxM);
            // Round the N block to a multiple of the 24-wide unrolled kernel.
            const int kernelWidth = 24;
            var sizeN = Mathf.ClosestPowerOfTwo(AN);
            sizeN = (sizeN / kernelWidth) * kernelWidth;
            sizeN = Mathf.Max(sizeN, kernelWidth);
            blockSizeN = Mathf.Min(sizeN, maxN);
            // Adjust block size down to the actual count of rows, so no allocation takes place needlessly
            blockSizeK = Mathf.Min(BM, maxK);
        }
#endif
        // Distribute jobs over a single axis
        int longerAxis = AM;
        int blockSizeForLongerAxis = blockSizeM;
        if (BN > AM)
        {
            longerAxis = BN; blockSizeForLongerAxis = blockSizeN;
        }
        var workElements = (longerAxis + blockSizeForLongerAxis - 1) / blockSizeForLongerAxis;
        return IJobParallelForExtensions.Schedule(this, workElements, blocksBatchCount, dependsOn);
    }
    // One work item: i indexes a block along the longer axis; the loop below
    // walks blocks along the shorter one.
    public void Execute(int i)
    {
        int shorterAxis = BN;
        int blockSizeForShorterAxis = blockSizeN;
        if (BN > AM)
        {
            shorterAxis = AM; blockSizeForShorterAxis = blockSizeM;
        }
        // Scratch blocks, lazily allocated only when a tile needs zero-padding
        // (remainder tiles) or a transposed copy; reused across iterations.
        float* blockTempA = null;
        float* blockTempB = null;
        float* blockTempC = null;
        // this job is scheduled over Max(AM, BN)
        // need to pick the remaining (shorter) axis
        for (int j = 0; j < shorterAxis; j += blockSizeForShorterAxis)
        {
            int rowA = (AM >= BN) ? i * blockSizeM: j;
            int colB = (AM >= BN) ? j : i * blockSizeN;
            float* blockC = C + rowA * CN + colB;
            int strideC = CN;
            if (rowA + blockSizeM > CM || colB + blockSizeN > CN) // copy remainder of C into zero-padded block
            {
                if (blockTempC == null)
                    blockTempC = AllocBlock(blockSizeM, blockSizeN);
                blockC = blockTempC;
                strideC = CopyBlockWithPadding(C, rowA, CM, colB, CN, blockC, blockSizeM, blockSizeN);
            }
            for (int l = 0; l < AN; l += blockSizeK) // inner-loop
            {
                float* blockA = A + rowA * AN + l;
                float* blockB = B + l * BN + colB;
                int strideA = AN;
                int strideB = BN;
                if (rowA + blockSizeM > AM || l + blockSizeK > AN || transposeA) // copy remainder of A or transposed A into zero-padded block
                {
                    if (blockTempA == null)
                        blockTempA = AllocBlock(blockSizeM, blockSizeK);
                    blockA = blockTempA;
                    strideA = CopyBlockWithPadding(A, rowA, AM, l, AN, blockA, blockSizeM, blockSizeK, transposeA);
                }
                if (colB + blockSizeN > BN || l + blockSizeK > BM || transposeB) // copy remainder of B or transposed B into zero-padded block
                {
                    if (blockTempB == null)
                        blockTempB = AllocBlock(blockSizeK, blockSizeN);
                    blockB = blockTempB;
                    strideB = CopyBlockWithPadding(B, l, BM, colB, BN, blockB, blockSizeK, blockSizeN, transposeB);
                }
                // Use defines instead of Application.isMobilePlatform || Application.isConsolePlatform, so we don't interrupt Burst
                // inlining or introduce a branch here in the inner loop
#if UNITY_ANDROID || UNITY_IOS || UNITY_WSA || UNITY_PS4 || UNITY_PS5 || UNITY_XBOXONE
                MultiplyBlockUnroll1x8(blockA, strideA, blockB, strideB, blockC, strideC,
                    blockSizeM, blockSizeK, Math.Min(blockSizeN, BN - colB));
#else
                MultiplyBlockUnroll3x24(blockA, strideA, blockB, strideB, blockC, strideC,
                    blockSizeM, blockSizeK, Math.Min(blockSizeN, BN - colB));
#endif
            }
            if (blockC == blockTempC) // copy back
                CopyBlock(blockC, C, rowA, CM, colB, CN, blockSizeM, blockSizeN);
        }
        // Release the scratch blocks once, after all tiles are processed
        // (matches MatrixMultiplyLegacyJob). Previously these ran inside the
        // loop while the cached pointers were still reused by later
        // iterations — only harmless because FreeBlock is a no-op for
        // Allocator.Temp; this ordering stays correct if the allocator changes.
        FreeBlock(blockTempA);
        FreeBlock(blockTempB);
        FreeBlock(blockTempC);
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Blocked single-precision matrix multiply C += A * B over fixed 16x16 tiles.
// Row-major: A is AM x AN, B is BM x BN, C is CM x CN. Remainder tiles and
// transposed inputs are staged through zero-padded TempJob scratch blocks.
unsafe struct MatrixMultiplyLegacyJob : IJobParallelFor
{
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* A;
    public int AM, AN;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* B;
    public int BM, BN;
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* C;
    public int CM, CN;
    public bool transposeA;
    public bool transposeB;
    public const int blockSize = 16;
    public JobHandle Schedule(JobHandle dependsOn)
    {
        return Schedule(blocksBatchCount:1, dependsOn);
    }
    // Swaps dims for transposed inputs, then schedules one work item per
    // 16-wide block along the longer of the two output axes.
    public JobHandle Schedule(int blocksBatchCount, JobHandle dependsOn)
    {
        if (transposeA)
        {
            int tmp = AM; AM = AN; AN = tmp;
        }
        if (transposeB)
        {
            int tmp = BM; BM = BN; BN = tmp;
        }
        int n = math.max(AM, BN);
        int workElements = (n + blockSize - 1) / blockSize;
        return IJobParallelForExtensions.Schedule(this, workElements, blocksBatchCount, dependsOn);
    }
    // One work item: i indexes a 16-block along the longer of AM/BN; the loop
    // below walks 16-blocks along the shorter axis.
    public void Execute(int i)
    {
        int bs = blockSize;
        unsafe
        {
            // Scratch tiles, lazily allocated (TempJob) only when a tile needs
            // zero-padding or a transposed copy; freed after the loop.
            float* blockTempA = null;
            float* blockTempB = null;
            float* blockTempC = null;
            // this job is scheduled over the Max(AM, BN)
            // need to pick the remaining (shorter) axis
            for (int j = 0; j < Math.Min(AM, BN); j += bs)
            {
                int rowA = (AM > BN) ? i * bs: j;
                int colB = (AM > BN) ? j : i * bs;
                float* blockC = C + rowA * CN + colB;
                int strideC = CN;
                if (rowA + bs > CM || colB + bs > CN) // copy remainder of C into zero-padded block
                {
                    if (blockTempC == null)
                        blockTempC = AllocBlock();
                    blockC = blockTempC;
                    strideC = bs;
                    MatrixUtils.CopyBlockWithPadding(C, rowA, CM, colB, CN, blockC, bs);
                }
                for (int l = 0; l < AN; l += bs) // inner-loop
                {
                    float* blockA = A + rowA * AN + l;
                    float* blockB = B + l * BN + colB;
                    int strideA = AN;
                    int strideB = BN;
                    if (rowA + bs > AM || l + bs > AN || transposeA) // copy remainder of A or transposed A into zero-padded block
                    {
                        if (blockTempA == null)
                            blockTempA = AllocBlock();
                        blockA = blockTempA;
                        strideA = bs;
                        MatrixUtils.CopyBlockWithPadding(A, rowA, AM, l, AN, blockA, bs, transposeA);
                    }
                    if (colB + bs > BN || l + bs > BM || transposeB) // copy remainder of B or transposed B into zero-padded block
                    {
                        if (blockTempB == null)
                            blockTempB = AllocBlock();
                        blockB = blockTempB;
                        strideB = bs;
                        MatrixUtils.CopyBlockWithPadding(B, l, BM, colB, BN, blockB, bs, transposeB);
                    }
                    MultiplyBlockUnrollHx16(blockA, strideA, blockB, strideB, blockC, strideC);
                }
                if (blockC == blockTempC) // copy back
                    MatrixUtils.CopyBlockWithPadding(blockC, C, rowA, CM, colB, CN, bs);
            }
            // TempJob allocations must be released explicitly (unlike the
            // Allocator.Temp variant used by MatrixMultiplyJob).
            FreeBlock(blockTempA);
            FreeBlock(blockTempB);
            FreeBlock(blockTempC);
        }
    }
    // 16x16 float scratch tile, cache-line aligned, TempJob lifetime.
    static unsafe float* AllocBlock()
    {
        const int sz = blockSize * blockSize * sizeof(float);
        return (float*)UnsafeUtility.Malloc(sz, JobsUtility.CacheLineSize, Allocator.TempJob);
    }
    static unsafe void FreeBlock(float* ptr)
    {
        if (ptr != null)
            UnsafeUtility.Free(ptr, Allocator.TempJob);
    }
    // 16x16 inner kernel: C[tile] += A[tile] * B[tile]. The 16 consecutive
    // columns of each C row are kept in scalar accumulators across the k-loop
    // so Burst can hold them in registers / vectorize.
    static unsafe void MultiplyBlockUnrollHx16(float* Ap, int Astride, float* Bp, int Bstride, float* Cp, int Cstride)
    {
        for (int i = 0; i < blockSize; i++)
        {
            for (int j = 0; j < blockSize; j += 16)
            {
                int baseC = i * Cstride + j;
                float sum0 = *(Cp + baseC + 0);
                float sum1 = *(Cp + baseC + 1);
                float sum2 = *(Cp + baseC + 2);
                float sum3 = *(Cp + baseC + 3);
                float sum4 = *(Cp + baseC + 4);
                float sum5 = *(Cp + baseC + 5);
                float sum6 = *(Cp + baseC + 6);
                float sum7 = *(Cp + baseC + 7);
                float sum8 = *(Cp + baseC + 8);
                float sum9 = *(Cp + baseC + 9);
                float sumA = *(Cp + baseC +10);
                float sumB = *(Cp + baseC +11);
                float sumC = *(Cp + baseC +12);
                float sumD = *(Cp + baseC +13);
                float sumE = *(Cp + baseC +14);
                float sumF = *(Cp + baseC +15);
                for (int l = 0; l < blockSize; l++)
                {
                    float A = *(Ap + i * Astride + l);
                    int baseB = l * Bstride + j;
                    sum0 += A * (*(Bp + baseB + 0));
                    sum1 += A * (*(Bp + baseB + 1));
                    sum2 += A * (*(Bp + baseB + 2));
                    sum3 += A * (*(Bp + baseB + 3));
                    sum4 += A * (*(Bp + baseB + 4));
                    sum5 += A * (*(Bp + baseB + 5));
                    sum6 += A * (*(Bp + baseB + 6));
                    sum7 += A * (*(Bp + baseB + 7));
                    sum8 += A * (*(Bp + baseB + 8));
                    sum9 += A * (*(Bp + baseB + 9));
                    sumA += A * (*(Bp + baseB +10));
                    sumB += A * (*(Bp + baseB +11));
                    sumC += A * (*(Bp + baseB +12));
                    sumD += A * (*(Bp + baseB +13));
                    sumE += A * (*(Bp + baseB +14));
                    sumF += A * (*(Bp + baseB +15));
                }
                *(Cp + baseC + 0) = sum0;
                *(Cp + baseC + 1) = sum1;
                *(Cp + baseC + 2) = sum2;
                *(Cp + baseC + 3) = sum3;
                *(Cp + baseC + 4) = sum4;
                *(Cp + baseC + 5) = sum5;
                *(Cp + baseC + 6) = sum6;
                *(Cp + baseC + 7) = sum7;
                *(Cp + baseC + 8) = sum8;
                *(Cp + baseC + 9) = sum9;
                *(Cp + baseC +10) = sumA;
                *(Cp + baseC +11) = sumB;
                *(Cp + baseC +12) = sumC;
                *(Cp + baseC +13) = sumD;
                *(Cp + baseC +14) = sumE;
                *(Cp + baseC +15) = sumF;
            }
        }
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Batched matrix multiply C = A * B over 16x16 tiles. A and C are addressed
// column-major here (element (row, col) at row + numRows * col) and carry a
// per-batch offset; B is row-major and shared by every batch (no batch
// offset). The output is overwritten, not accumulated: each C tile is zeroed
// before the k-loop.
unsafe struct MatrixMultiply3x2Job : IJobParallelFor, IJobResourceDeclarationXBO
{
    public ReadOnlyMemResource X { get; set; } float* Aptr => X.ptrfloat;
    public ReadOnlyMemResource B { get; set; } float* Bptr => B.ptrfloat;
    public ReadWriteMemResource O { get; set; } float* Cptr => O.ptrfloat;
    public int AM, AN;
    public int BM, BN;
    public int CM, CN;
    // Flattened 3D dispatch: threadID encodes (batch, tile-row i, tile-col j).
    public int dispatchThreadX, dispatchThreadY, dispatchThreadZ;
    public const int blockSize = 16;
    public void Execute(int threadID)
    {
        // Decompose the flat threadID into (batch, i, j).
        int dispatchThreadXY = dispatchThreadX * dispatchThreadY;
        int batch = (threadID / dispatchThreadXY);
        int i = (threadID % dispatchThreadXY) % dispatchThreadX;
        int j = (threadID % dispatchThreadXY) / dispatchThreadX;
        int batchOffSetA = (batch * AM * AN);
        int batchOffSetC = (batch * CM * CN);
        int rowA = i * blockSize;
        int colB = j * blockSize;
        unsafe
        {
            // Scratch tiles, lazily allocated only for remainder tiles.
            float* blockTempA = null;
            float* blockTempB = null;
            float* blockTempC = null;
            float* blockC = Cptr + rowA + CM * colB + batchOffSetC;
            int strideC = CM;
            if (rowA + blockSize > CM || colB + blockSize > CN) // copy remainder of C into zero-padded block
            {
                blockTempC = AllocBlock(blockSize, blockSize);
                strideC = blockSize;
                blockC = blockTempC;
            }
            // C = A * B (not +=): clear the output tile before accumulating.
            for (int y = 0; y < blockSize; y++)
                for (int x = 0; x < blockSize; x++)
                    blockC[x + strideC * y] = 0.0f;
            for (int l = 0; l < AN; l += blockSize) // inner-loop
            {
                float* blockA = Aptr + rowA + AM * l + batchOffSetA;
                float* blockB = Bptr + l * BN + colB;
                int strideA = AM;
                int strideB = BN;
                if (rowA + blockSize > AM || l + blockSize > AN) // copy remainder of A into zero-padded block
                {
                    if (blockTempA == null)
                        blockTempA = AllocBlock(blockSize, blockSize);
                    strideA = blockSize;
                    for (int y = 0; y < blockSize; y++)
                        for (int x = 0; x < blockSize; x++)
                            blockTempA[x + blockSize * y] = ((rowA + x) < AM && (l + y < AN)) ? blockA[x + AM * y] : 0.0f;
                    blockA = blockTempA;
                }
                if (colB + blockSize > BN || l + blockSize > BM) // copy remainder of B into zero-padded block
                {
                    if (blockTempB == null)
                        blockTempB = AllocBlock(blockSize, blockSize);
                    strideB = blockSize;
                    for (int y = 0; y < blockSize; y++)
                        for (int x = 0; x < blockSize; x++)
                            blockTempB[x + blockSize * y] = ((colB + x) < BN && (l + y < BM)) ? blockB[x + BN * y] : 0.0f;
                    blockB = blockTempB;
                }
                MultiplyBlockUnrollHx16(blockA, strideA, blockB, strideB, blockC, strideC);
            }
            if (blockC == blockTempC) // copy back
            {
                // Only in-bounds elements of the padded tile are written back.
                for (int y = 0; y < blockSize; y++)
                    for (int x = 0; x < blockSize; x++)
                    {
                        if (((rowA + x) < CM) && ((colB + y) < CN))
                            Cptr[(rowA + x) + CM * (colB + y) + batchOffSetC] = blockTempC[x + blockSize * y];
                    }
            }
            FreeBlock(blockTempA);
            FreeBlock(blockTempB);
            FreeBlock(blockTempC);
        }
    }
    // 16x16 inner kernel for this layout: A and C column-major, B row-major.
    // One column of C (16 rows) is accumulated in registers across the k-loop.
    static void MultiplyBlockUnrollHx16(float* Ap, int Astride, float* Bp, int Bstride, float* Cp, int Cstride)
    {
        for (int i = 0; i < blockSize; i++)
        {
            float sum0 = *(Cp + i + Cstride * 0);
            float sum1 = *(Cp + i + Cstride * 1);
            float sum2 = *(Cp + i + Cstride * 2);
            float sum3 = *(Cp + i + Cstride * 3);
            float sum4 = *(Cp + i + Cstride * 4);
            float sum5 = *(Cp + i + Cstride * 5);
            float sum6 = *(Cp + i + Cstride * 6);
            float sum7 = *(Cp + i + Cstride * 7);
            float sum8 = *(Cp + i + Cstride * 8);
            float sum9 = *(Cp + i + Cstride * 9);
            float sumA = *(Cp + i + Cstride * 10);
            float sumB = *(Cp + i + Cstride * 11);
            float sumC = *(Cp + i + Cstride * 12);
            float sumD = *(Cp + i + Cstride * 13);
            float sumE = *(Cp + i + Cstride * 14);
            float sumF = *(Cp + i + Cstride * 15);
            for (int l = 0; l < blockSize; l++)
            {
                float A = *(Ap + i + Astride * l);
                float B0 = *(Bp + l * Bstride + 0);
                float B1 = *(Bp + l * Bstride + 1);
                float B2 = *(Bp + l * Bstride + 2);
                float B3 = *(Bp + l * Bstride + 3);
                float B4 = *(Bp + l * Bstride + 4);
                float B5 = *(Bp + l * Bstride + 5);
                float B6 = *(Bp + l * Bstride + 6);
                float B7 = *(Bp + l * Bstride + 7);
                float B8 = *(Bp + l * Bstride + 8);
                float B9 = *(Bp + l * Bstride + 9);
                float BA = *(Bp + l * Bstride + 10);
                float BB = *(Bp + l * Bstride + 11);
                float BC = *(Bp + l * Bstride + 12);
                float BD = *(Bp + l * Bstride + 13);
                float BE = *(Bp + l * Bstride + 14);
                float BF = *(Bp + l * Bstride + 15);
                sum0 += A * B0;
                sum1 += A * B1;
                sum2 += A * B2;
                sum3 += A * B3;
                sum4 += A * B4;
                sum5 += A * B5;
                sum6 += A * B6;
                sum7 += A * B7;
                sum8 += A * B8;
                sum9 += A * B9;
                sumA += A * BA;
                sumB += A * BB;
                sumC += A * BC;
                sumD += A * BD;
                sumE += A * BE;
                sumF += A * BF;
            }
            *(Cp + i + Cstride * 0 ) = sum0;
            *(Cp + i + Cstride * 1 ) = sum1;
            *(Cp + i + Cstride * 2 ) = sum2;
            *(Cp + i + Cstride * 3 ) = sum3;
            *(Cp + i + Cstride * 4 ) = sum4;
            *(Cp + i + Cstride * 5 ) = sum5;
            *(Cp + i + Cstride * 6 ) = sum6;
            *(Cp + i + Cstride * 7 ) = sum7;
            *(Cp + i + Cstride * 8 ) = sum8;
            *(Cp + i + Cstride * 9 ) = sum9;
            *(Cp + i + Cstride * 10) = sumA;
            *(Cp + i + Cstride * 11) = sumB;
            *(Cp + i + Cstride * 12) = sumC;
            *(Cp + i + Cstride * 13) = sumD;
            *(Cp + i + Cstride * 14) = sumE;
            *(Cp + i + Cstride * 15) = sumF;
        }
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Matrix multiply with two batch dimensions (outer *B0, inner *B1). Matrices
// are row-major with an element stride equal to the inner batch count:
// element (row, col) of batch (b0, b1) sits at ((row * N + col) * B1 + offset).
// A's and B's batch dims broadcast against C's via modulo (batch0 % AB0 etc.).
// Output is overwritten, not accumulated (tiles are zeroed before the k-loop).
unsafe struct MatrixMultiply4x4Job : IJobParallelFor, IJobResourceDeclarationXBO
{
    public ReadOnlyMemResource X { get; set; } float* Aptr => X.ptrfloat;
    public ReadOnlyMemResource B { get; set; } float* Bptr => B.ptrfloat;
    public ReadWriteMemResource O { get; set; } float* Cptr => O.ptrfloat;
    public int AB0, AB1, AM, AN;
    public int BB0, BB1, BM, BN;
    public int CB1, CM, CN;
    // Flattened dispatch: threadID encodes (batch0, tile-row i, tile-col j, batch1).
    public int dispatchThreadX, dispatchThreadY, dispatchThreadZ;
    public const int blockSize = 16;
    public void Execute(int threadID)
    {
        // Decompose the flat threadID into (batch0, batch1, i, j).
        int dispatchThreadXY = dispatchThreadX * dispatchThreadY;
        int batch1 = (threadID % CB1);
        int batch0 = (threadID / CB1) / dispatchThreadXY;
        int i = ((threadID / CB1) % dispatchThreadXY) % dispatchThreadX;
        int j = ((threadID / CB1) % dispatchThreadXY) / dispatchThreadX;
        // A and B broadcast over batch dims via modulo against C's batch index.
        int batchOffSetA = ((batch0 % AB0) * AM * AN * AB1 + (batch1 % AB1));
        int batchOffSetB = ((batch0 % BB0) * BM * BN * BB1 + (batch1 % BB1));
        int batchOffSetC = (batch0 * CM * CN * CB1 + batch1);
        int rowA = i * blockSize;
        int colB = j * blockSize;
        unsafe
        {
            // Scratch tiles, lazily allocated only for remainder tiles; packed
            // tiles are dense (batch stride 1).
            float* blockTempA = null;
            float* blockTempB = null;
            float* blockTempC = null;
            float* blockC = Cptr + (rowA * CN + colB)*CB1 + batchOffSetC;
            int strideC = CN;
            int strideBatchC = CB1;
            if (rowA + blockSize > CM || colB + blockSize > CN) // copy remainder of C into zero-padded block
            {
                blockTempC = AllocBlock(blockSize, blockSize);
                strideC = blockSize;
                strideBatchC = 1;
                blockC = blockTempC;
            }
            // C = A * B (not +=): clear the output tile before accumulating.
            for (int y = 0; y < blockSize; y++)
                for (int x = 0; x < blockSize; x++)
                    blockC[(x + strideC * y) * strideBatchC] = 0.0f;
            for (int l = 0; l < AN; l += blockSize) // inner-loop
            {
                float* blockA = Aptr + (rowA * AN + l)*AB1 + batchOffSetA;
                float* blockB = Bptr + (l * BN + colB)*BB1 + batchOffSetB;
                int strideA = AN;
                int strideBatchA = AB1;
                int strideB = BN;
                int strideBatchB = BB1;
                if (rowA + blockSize > AM || l + blockSize > AN) // copy remainder of A into zero-padded block
                {
                    if (blockTempA == null)
                        blockTempA = AllocBlock(blockSize, blockSize);
                    strideA = blockSize;
                    strideBatchA = 1;
                    for (int y = 0; y < blockSize; y++)
                        for (int x = 0; x < blockSize; x++)
                            blockTempA[x + blockSize * y] = ((rowA + y) < AM && (l + x < AN)) ? blockA[(x + AN * y)*AB1] : 0.0f;
                    blockA = blockTempA;
                }
                if (colB + blockSize > BN || l + blockSize > BM) // copy remainder of B into zero-padded block
                {
                    if (blockTempB == null)
                        blockTempB = AllocBlock(blockSize, blockSize);
                    strideB = blockSize;
                    strideBatchB = 1;
                    for (int y = 0; y < blockSize; y++)
                        for (int x = 0; x < blockSize; x++)
                            blockTempB[x + blockSize * y] = ((colB + x) < BN && (l + y < BM)) ? blockB[(x + BN * y)*BB1] : 0.0f;
                    blockB = blockTempB;
                }
                MultiplyBlockUnrollHx16(blockA, strideA, strideBatchA, blockB, strideB, strideBatchB, blockC, strideC, strideBatchC);
            }
            if (blockC == blockTempC) // copy back
            {
                // Only in-bounds elements of the padded tile are written back.
                for (int y = 0; y < blockSize; y++)
                    for (int x = 0; x < blockSize; x++)
                    {
                        if (((rowA + y) < CM) && (colB + x < CN))
                            Cptr[((rowA + y) * CN + (colB + x)) * CB1 + batchOffSetC] = blockTempC[x + blockSize * y];
                    }
            }
            FreeBlock(blockTempA);
            FreeBlock(blockTempB);
            FreeBlock(blockTempC);
        }
    }
    // 16x16 inner kernel with a per-matrix element ("batch") stride so it can
    // read tiles either in place (stride = *B1) or packed (stride = 1).
    // The 16 columns of each C row are accumulated in registers.
    static void MultiplyBlockUnrollHx16(float* Ap, int Astride, int ABatchStride, float* Bp, int Bstride, int BBatchStride, float* Cp, int Cstride, int CBatchStride)
    {
        for (int i = 0; i < blockSize; i++)
        {
            float sum0 = *(Cp + (i * Cstride + 0 )*CBatchStride);
            float sum1 = *(Cp + (i * Cstride + 1 )*CBatchStride);
            float sum2 = *(Cp + (i * Cstride + 2 )*CBatchStride);
            float sum3 = *(Cp + (i * Cstride + 3 )*CBatchStride);
            float sum4 = *(Cp + (i * Cstride + 4 )*CBatchStride);
            float sum5 = *(Cp + (i * Cstride + 5 )*CBatchStride);
            float sum6 = *(Cp + (i * Cstride + 6 )*CBatchStride);
            float sum7 = *(Cp + (i * Cstride + 7 )*CBatchStride);
            float sum8 = *(Cp + (i * Cstride + 8 )*CBatchStride);
            float sum9 = *(Cp + (i * Cstride + 9 )*CBatchStride);
            float sumA = *(Cp + (i * Cstride + 10)*CBatchStride);
            float sumB = *(Cp + (i * Cstride + 11)*CBatchStride);
            float sumC = *(Cp + (i * Cstride + 12)*CBatchStride);
            float sumD = *(Cp + (i * Cstride + 13)*CBatchStride);
            float sumE = *(Cp + (i * Cstride + 14)*CBatchStride);
            float sumF = *(Cp + (i * Cstride + 15)*CBatchStride);
            for (int l = 0; l < blockSize; l++)
            {
                float A = *(Ap + (i * Astride + l)*ABatchStride);
                float B0 = *(Bp + (l * Bstride + 0 )*BBatchStride);
                float B1 = *(Bp + (l * Bstride + 1 )*BBatchStride);
                float B2 = *(Bp + (l * Bstride + 2 )*BBatchStride);
                float B3 = *(Bp + (l * Bstride + 3 )*BBatchStride);
                float B4 = *(Bp + (l * Bstride + 4 )*BBatchStride);
                float B5 = *(Bp + (l * Bstride + 5 )*BBatchStride);
                float B6 = *(Bp + (l * Bstride + 6 )*BBatchStride);
                float B7 = *(Bp + (l * Bstride + 7 )*BBatchStride);
                float B8 = *(Bp + (l * Bstride + 8 )*BBatchStride);
                float B9 = *(Bp + (l * Bstride + 9 )*BBatchStride);
                float BA = *(Bp + (l * Bstride + 10)*BBatchStride);
                float BB = *(Bp + (l * Bstride + 11)*BBatchStride);
                float BC = *(Bp + (l * Bstride + 12)*BBatchStride);
                float BD = *(Bp + (l * Bstride + 13)*BBatchStride);
                float BE = *(Bp + (l * Bstride + 14)*BBatchStride);
                float BF = *(Bp + (l * Bstride + 15)*BBatchStride);
                sum0 += A * B0;
                sum1 += A * B1;
                sum2 += A * B2;
                sum3 += A * B3;
                sum4 += A * B4;
                sum5 += A * B5;
                sum6 += A * B6;
                sum7 += A * B7;
                sum8 += A * B8;
                sum9 += A * B9;
                sumA += A * BA;
                sumB += A * BB;
                sumC += A * BC;
                sumD += A * BD;
                sumE += A * BE;
                sumF += A * BF;
            }
            *(Cp + (i * Cstride + 0 )*CBatchStride) = sum0;
            *(Cp + (i * Cstride + 1 )*CBatchStride) = sum1;
            *(Cp + (i * Cstride + 2 )*CBatchStride) = sum2;
            *(Cp + (i * Cstride + 3 )*CBatchStride) = sum3;
            *(Cp + (i * Cstride + 4 )*CBatchStride) = sum4;
            *(Cp + (i * Cstride + 5 )*CBatchStride) = sum5;
            *(Cp + (i * Cstride + 6 )*CBatchStride) = sum6;
            *(Cp + (i * Cstride + 7 )*CBatchStride) = sum7;
            *(Cp + (i * Cstride + 8 )*CBatchStride) = sum8;
            *(Cp + (i * Cstride + 9 )*CBatchStride) = sum9;
            *(Cp + (i * Cstride + 10)*CBatchStride) = sumA;
            *(Cp + (i * Cstride + 11)*CBatchStride) = sumB;
            *(Cp + (i * Cstride + 12)*CBatchStride) = sumC;
            *(Cp + (i * Cstride + 13)*CBatchStride) = sumD;
            *(Cp + (i * Cstride + 14)*CBatchStride) = sumE;
            *(Cp + (i * Cstride + 15)*CBatchStride) = sumF;
        }
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Element-wise fp16 -> fp32 conversion; one element per parallel-for index.
unsafe struct ConvertHalfToFloatJob : IJobParallelFor, IJobResourceDeclarationXO
{
    public ReadOnlyMemResource X { get; set; } half* Xptr => X.ptrhalf;
    public ReadWriteMemResource O { get; set; } float* Optr => O.ptrfloat;
    public void Execute(int threadID)
    {
        half value = Xptr[threadID];
        Optr[threadID] = (float)value;
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Element-wise fp32 -> fp16 conversion; one element per parallel-for index.
unsafe struct ConvertFloatToHalfJob : IJobParallelFor, IJobResourceDeclarationXO
{
    public ReadOnlyMemResource X { get; set; } float* Xptr => X.ptrfloat;
    public ReadWriteMemResource O { get; set; } half* Optr => O.ptrhalf;
    public void Execute(int threadID)
    {
        float value = Xptr[threadID];
        Optr[threadID] = (half)value;
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
// Builds one horizontal slice of an im2col matrix. For each output row y
// (the parallel-for index) and each batch: rows mapping outside the input's
// vertical range are fully zeroed; in-range rows get zero left padding, a
// (possibly horizontally strided) copy of one input row, then zero right
// padding. All sizes are in channel-interleaved floats.
unsafe struct Im2ColSliceJob : IJobParallelFor, IJobResourceDeclarationXO
{
    public ReadOnlyMemResource X { get; set; }
    public ReadWriteMemResource O { get; set; }
    [ReadOnly] public int inOutBatch, inOutChannels;
    [ReadOnly] public int inHeight, inStrideN, inStrideH, inStrideW;
    [ReadOnly] public int outWidth, outStrideN, outStrideH;
    [ReadOnly] public int strideX, strideY, offsetY;
    [ReadOnly] public int padLeft, padRight, skipFromInputRow, copyFromInputRow;
    public void Execute(int y)
    {
        for (int n = 0; n < inOutBatch; ++n)
        {
            // Source row in the input for this output row (can be out of range
            // when it falls into the vertical padding).
            int readY = strideY * y + offsetY;
            float* from = X.ptrfloat + n * inStrideN + readY * inStrideH + skipFromInputRow * inStrideW;
            float* to = O.ptrfloat + n * outStrideN + y * outStrideH;
            if (readY < 0 ||
                readY >= inHeight)
            {
                // pad-0 top or bottom line, len = outWidth
                UnsafeUtility.MemClear(destination: to,
                    size: inOutChannels * outWidth * sizeof(float));
                to += inOutChannels * outWidth;
            }
            else
            {
                // pad-0 left, len = padLeft
                UnsafeUtility.MemClear(destination: to,
                    size: inOutChannels * padLeft * sizeof(float));
                to += inOutChannels * padLeft;
                // copy from X with stride, if necessary
                if (strideX == 1)
                {
                    // Contiguous pixels: single bulk copy.
                    UnsafeUtility.MemCpy(destination: to,
                        source: from,
                        size: inOutChannels * copyFromInputRow * sizeof(float));
                    to += inOutChannels * copyFromInputRow;
                }
                else
                {
                    // Strided pixels: copy one pixel (inOutChannels floats) at
                    // a time, skipping strideX pixels in the source.
                    UnsafeUtility.MemCpyStride(destination: to, destinationStride: inOutChannels * sizeof(float),
                        source: from, sourceStride: strideX * inOutChannels * sizeof(float),
                        elementSize: inOutChannels * sizeof(float),
                        count: copyFromInputRow);
                    to += inOutChannels * copyFromInputRow;
                }
                // pad-0 right, len = padRight
                UnsafeUtility.MemClear(destination: to,
                    size: inOutChannels * padRight * sizeof(float));
                to += inOutChannels * padRight;
            }
        }
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
unsafe struct ZeroBroadcastJob : IJob, IJobResourceDeclarationO
{
    // Zero-fills the output buffer: `repeat` floats are cleared in one call.
    public ReadWriteMemResource O { get; set; }
    [ReadOnly] public int repeat;

    public void Execute()
    {
        long byteCount = repeat * sizeof(float);
        UnsafeUtility.MemClear(O.ptr, byteCount);
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
unsafe struct VectorBroadcastJob : IJob, IJobResourceDeclarationXO
{
    // Broadcasts one contiguous vector of `channels` floats from X into O,
    // tiling it `repeat` times back-to-back.
    public ReadOnlyMemResource X { get; set; }
    public ReadWriteMemResource O { get; set; }
    [ReadOnly] public int channels;
    [ReadOnly] public int repeat;

    public void Execute()
    {
        int rowBytes = channels * sizeof(float);
        UnsafeUtility.MemCpyReplicate(O.ptr, X.ptr, rowBytes, repeat);
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
unsafe struct MemFreeJob : IJob
{
    // Releases up to two native buffers from the job system, so deallocation
    // can be chained behind the jobs that last used them. Null pointers are
    // skipped, so either slot may be left unset.
    [NoAlias] [NativeDisableUnsafePtrRestriction] public void* buffer0;
    [NoAlias] [NativeDisableUnsafePtrRestriction] public void* buffer1;
    [ReadOnly] public Allocator allocator;

    public void Execute()
    {
        Release(buffer0);
        Release(buffer1);
    }

    // Frees `p` with this job's allocator when it is non-null.
    void Release(void* p)
    {
        if (p != null)
            UnsafeUtility.Free(p, allocator);
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Default, FloatPrecision = FloatPrecision.Standard)]
unsafe struct LSTMEndJob : IJobParallelFor
{
    // Combines the precomputed input (w) and recurrent (r) gate pre-activations
    // of one LSTM step into the new cell and hidden state, element-wise over
    // (batch * hiddenSize). Writes the hidden state both into the sequence
    // output O and into the per-step hidden/cell scratch buffers.
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* i_mad_w;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* j_mad_w;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* f_mad_w;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* o_mad_w;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* i_mad_r;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* j_mad_r;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* f_mad_r;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* o_mad_r;
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* cell;
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* O;
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* cell_out;
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* hidden_out;
    public int sequenceIndexO, sequenceIndexI;
    public int batchSize, hiddenSize;
    public int batchSizeR; // batch count of the recurrent buffers (may be smaller; indices wrap)

    public JobHandle Schedule(int arrayLength, int innerloopBatchCount, JobHandle dependsOn)
    {
        return IJobParallelForExtensions.Schedule(this, arrayLength, innerloopBatchCount, dependsOn);
    }

    public void Execute(int threadId)
    {
        // Decompose the flat thread id into (batch, hidden-unit) coordinates.
        int batch = threadId / hiddenSize;
        int hidden = threadId % hiddenSize;
        // Recurrent-side buffers only hold batchSizeR batches: wrap around.
        int recurrentId = (batch % batchSizeR) * hiddenSize + hidden;
        // Input-side buffers hold the whole sequence; select the current step.
        int gateOffset = batchSize * hiddenSize * sequenceIndexI + threadId;

        float gateI = Sigmoid(i_mad_w[gateOffset] + i_mad_r[recurrentId]);
        float gateJ = math.tanh(j_mad_w[gateOffset] + j_mad_r[recurrentId]);
        float gateF = Sigmoid(f_mad_w[gateOffset] + f_mad_r[recurrentId]);
        float gateO = Sigmoid(o_mad_w[gateOffset] + o_mad_r[recurrentId]);

        // c_t = f * c_{t-1} + i * j ;  h_t = o * tanh(c_t)
        float newCell = cell[recurrentId] * gateF + gateI * gateJ;
        float newHidden = gateO * math.tanh(newCell);

        O[batchSize * hiddenSize * sequenceIndexO + threadId] = newHidden;
        hidden_out[threadId] = newHidden;
        cell_out[threadId] = newCell;
    }

    // Logistic sigmoid; same expression as the original inline form so the
    // float semantics under FloatMode.Default are unchanged.
    static float Sigmoid(float x)
    {
        return 1f / (1f + math.exp(-x));
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
unsafe struct LSTMDense3Job : IJobParallelFor
{
    // Batched blocked matrix multiply with bias: S[batch] = A[batch] * B + C.
    // One parallel-for index computes a single 16x16 tile of one batch of S.
    // Tiles overhanging a matrix edge are staged in zero-padded scratch blocks
    // so the unrolled inner kernel never reads or writes out of bounds.
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* A;
    public int AM, AN;  // A is AM x AN per batch
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* B;
    public int BM, BN;  // B is BM x BN, shared by all batches
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* C;
    public int CN;      // bias length; columns index C modulo CN (broadcast)
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* S;
    public int SM, SN;  // S is SM x SN per batch
    public int dispatchThreadX, dispatchThreadY, dispatchThreadZ; // tiles in x/y, batches in z
    public const int blockSize = 16;
    public JobHandle Schedule(JobHandle dependsOn)
    {
        return Schedule(blocksBatchCount:1, dependsOn);
    }
    public JobHandle Schedule(int blocksBatchCount, JobHandle dependsOn)
    {
        return IJobParallelForExtensions.Schedule(this, dispatchThreadX * dispatchThreadY * dispatchThreadZ, blocksBatchCount, dependsOn);
    }
    public void Execute(int threadID)
    {
        // Unflatten threadID into (tile row i, tile column j, batch).
        int dispatchThreadXY = dispatchThreadX * dispatchThreadY;
        int batch = (threadID / dispatchThreadXY);
        int i = (threadID % dispatchThreadXY) % dispatchThreadX;
        int j = (threadID % dispatchThreadXY) / dispatchThreadX;
        int batchOffSetA = (batch * AM * AN);
        int batchOffSetS = (batch * SM * SN);
        int rowA = i * blockSize;
        int colB = j * blockSize;
        unsafe
        {
            float* blockTempA = null;
            float* blockTempB = null;
            float* blockTempS = null;
            // Accumulate directly into S unless the tile overhangs an edge,
            // in which case accumulate into a square scratch block instead.
            float* blockS = S + rowA * SN + colB + batchOffSetS;
            int strideS = SN;
            if (rowA + blockSize > SM || colB + blockSize > SN) // copy remainder of C into zero-padded block
            {
                // AllocBlock/FreeBlock are helpers defined elsewhere in this class.
                blockTempS = AllocBlock(blockSize, blockSize);
                strideS = blockSize;
                blockS = blockTempS;
            }
            // Seed the accumulator with the bias C (broadcast via % CN),
            // zero for columns beyond BN.
            for (int y = 0; y < blockSize; y++)
            for (int x = 0; x < blockSize; x++)
                blockS[x + strideS * y] = (colB + x) < BN ? C[(colB + x)%CN] : 0.0f;
            for (int l = 0; l < AN; l += blockSize) // inner-loop
            {
                float* blockA = A + rowA * AN + l + batchOffSetA;
                float* blockB = B + l * BN + colB;
                int strideA = AN;
                int strideB = BN;
                if (rowA + blockSize > AM || l + blockSize > AN) // copy remainder of A into zero-padded block
                {
                    if (blockTempA == null)
                        blockTempA = AllocBlock(blockSize, blockSize);
                    strideA = blockSize;
                    for (int y = 0; y < blockSize; y++)
                    for (int x = 0; x < blockSize; x++)
                        blockTempA[x + blockSize * y] = ((rowA + y) < AM && (l + x < AN)) ? blockA[x + AN * y] : 0.0f;
                    blockA = blockTempA;
                }
                if (colB + blockSize > BN || l + blockSize > BM) // copy remainder of B into zero-padded block
                {
                    if (blockTempB == null)
                        blockTempB = AllocBlock(blockSize, blockSize);
                    strideB = blockSize;
                    for (int y = 0; y < blockSize; y++)
                    for (int x = 0; x < blockSize; x++)
                        blockTempB[x + blockSize * y] = ((colB + x) < BN && (l + y < BM)) ? blockB[x + BN * y] : 0.0f;
                    blockB = blockTempB;
                }
                MultiplyBlockUnrollHx16(blockA, strideA, blockB, strideB, blockS, strideS);
            }
            if (blockS == blockTempS) // copy back
            {
                // Only the in-bounds portion of the scratch tile is written to S.
                for (int y = 0; y < blockSize; y++)
                for (int x = 0; x < blockSize; x++)
                {
                    if (((rowA + y) < SM) && ((colB + x) < SN))
                        S[(rowA + y) * SN + (colB + x) + batchOffSetS] = blockTempS[x + blockSize * y];
                }
            }
            // FreeBlock tolerates null (blocks are only allocated on demand).
            FreeBlock(blockTempA);
            FreeBlock(blockTempB);
            FreeBlock(blockTempS);
        }
    }
    // 16x16 tile kernel: Sp += Ap * Bp, with the 16 output columns of each row
    // held in registers (sum0..sumF) across the unrolled k-loop.
    static void MultiplyBlockUnrollHx16(float* Ap, int Astride, float* Bp, int Bstride, float* Sp, int Sstride)
    {
        for (int i = 0; i < blockSize; i++)
        {
            float sum0 = *(Sp + i * Sstride + 0);
            float sum1 = *(Sp + i * Sstride + 1);
            float sum2 = *(Sp + i * Sstride + 2);
            float sum3 = *(Sp + i * Sstride + 3);
            float sum4 = *(Sp + i * Sstride + 4);
            float sum5 = *(Sp + i * Sstride + 5);
            float sum6 = *(Sp + i * Sstride + 6);
            float sum7 = *(Sp + i * Sstride + 7);
            float sum8 = *(Sp + i * Sstride + 8);
            float sum9 = *(Sp + i * Sstride + 9);
            float sumA = *(Sp + i * Sstride + 10);
            float sumB = *(Sp + i * Sstride + 11);
            float sumC = *(Sp + i * Sstride + 12);
            float sumD = *(Sp + i * Sstride + 13);
            float sumE = *(Sp + i * Sstride + 14);
            float sumF = *(Sp + i * Sstride + 15);
            for (int l = 0; l < blockSize; l++)
            {
                float A = *(Ap + i * Astride + l);
                float B0 = *(Bp + l * Bstride + 0);
                float B1 = *(Bp + l * Bstride + 1);
                float B2 = *(Bp + l * Bstride + 2);
                float B3 = *(Bp + l * Bstride + 3);
                float B4 = *(Bp + l * Bstride + 4);
                float B5 = *(Bp + l * Bstride + 5);
                float B6 = *(Bp + l * Bstride + 6);
                float B7 = *(Bp + l * Bstride + 7);
                float B8 = *(Bp + l * Bstride + 8);
                float B9 = *(Bp + l * Bstride + 9);
                float BA = *(Bp + l * Bstride + 10);
                float BB = *(Bp + l * Bstride + 11);
                float BC = *(Bp + l * Bstride + 12);
                float BD = *(Bp + l * Bstride + 13);
                float BE = *(Bp + l * Bstride + 14);
                float BF = *(Bp + l * Bstride + 15);
                sum0 += A * B0;
                sum1 += A * B1;
                sum2 += A * B2;
                sum3 += A * B3;
                sum4 += A * B4;
                sum5 += A * B5;
                sum6 += A * B6;
                sum7 += A * B7;
                sum8 += A * B8;
                sum9 += A * B9;
                sumA += A * BA;
                sumB += A * BB;
                sumC += A * BC;
                sumD += A * BD;
                sumE += A * BE;
                sumF += A * BF;
            }
            *(Sp + i * Sstride + 0 ) = sum0;
            *(Sp + i * Sstride + 1 ) = sum1;
            *(Sp + i * Sstride + 2 ) = sum2;
            *(Sp + i * Sstride + 3 ) = sum3;
            *(Sp + i * Sstride + 4 ) = sum4;
            *(Sp + i * Sstride + 5 ) = sum5;
            *(Sp + i * Sstride + 6 ) = sum6;
            *(Sp + i * Sstride + 7 ) = sum7;
            *(Sp + i * Sstride + 8 ) = sum8;
            *(Sp + i * Sstride + 9 ) = sum9;
            *(Sp + i * Sstride + 10) = sumA;
            *(Sp + i * Sstride + 11) = sumB;
            *(Sp + i * Sstride + 12) = sumC;
            *(Sp + i * Sstride + 13) = sumD;
            *(Sp + i * Sstride + 14) = sumE;
            *(Sp + i * Sstride + 15) = sumF;
        }
    }
}
[BurstCompile(OptimizeFor = OptimizeFor.Performance, FloatMode = FloatMode.Fast, FloatPrecision = FloatPrecision.Low)]
unsafe struct LSTMDenseJob : IJobParallelFor
{
    // Blocked matrix multiply with bias: S = A * B + C (non-batched twin of
    // LSTMDense3Job). One parallel-for index computes a single 16x16 tile of S.
    // Tiles overhanging a matrix edge are staged in zero-padded scratch blocks
    // so the unrolled inner kernel never reads or writes out of bounds.
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* A;
    public int AM, AN;  // A is AM x AN
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* B;
    public int BM, BN;  // B is BM x BN
    [NoAlias][NativeDisableUnsafePtrRestriction][ReadOnly] public unsafe float* C;
    public int CN;      // bias length; columns index C modulo CN (broadcast)
    [NoAlias][NativeDisableUnsafePtrRestriction] public unsafe float* S;
    public int SM, SN;  // output S is SM x SN
    public int dispatchThreadX, dispatchThreadY; // tile grid dimensions
    public const int blockSize = 16;
    public JobHandle Schedule(JobHandle dependsOn)
    {
        return Schedule(blocksBatchCount: 1, dependsOn);
    }
    public JobHandle Schedule(int blocksBatchCount, JobHandle dependsOn)
    {
        return IJobParallelForExtensions.Schedule(this, dispatchThreadX * dispatchThreadY, blocksBatchCount, dependsOn);
    }
    public void Execute(int threadID)
    {
        // Unflatten threadID into (tile row i, tile column j).
        int i = (threadID % dispatchThreadX);
        int j = (threadID / dispatchThreadX);
        int rowA = i * blockSize;
        int colB = j * blockSize;
        unsafe
        {
            float* blockTempA = null;
            float* blockTempB = null;
            float* blockTempS = null;
            // Accumulate directly into S unless the tile overhangs an edge,
            // in which case accumulate into a square scratch block instead.
            float* blockS = S + rowA * SN + colB;
            int strideS = SN;
            if (rowA + blockSize > SM || colB + blockSize > SN) // copy remainder of C into zero-padded block
            {
                // AllocBlock/FreeBlock are helpers defined elsewhere in this class.
                blockTempS = AllocBlock(blockSize, blockSize);
                strideS = blockSize;
                blockS = blockTempS;
            }
            // Seed the accumulator with the bias C (broadcast via % CN),
            // zero for columns beyond BN.
            for (int y = 0; y < blockSize; y++)
            for (int x = 0; x < blockSize; x++)
                blockS[x + strideS * y] = (colB + x) < BN ? C[(colB + x)%CN] : 0.0f;
            for (int l = 0; l < AN; l += blockSize) // inner-loop
            {
                float* blockA = A + rowA * AN + l;
                float* blockB = B + l * BN + colB;
                int strideA = AN;
                int strideB = BN;
                if (rowA + blockSize > AM || l + blockSize > AN) // copy remainder of A into zero-padded block
                {
                    if (blockTempA == null)
                        blockTempA = AllocBlock(blockSize, blockSize);
                    strideA = blockSize;
                    for (int y = 0; y < blockSize; y++)
                    for (int x = 0; x < blockSize; x++)
                        blockTempA[x + blockSize * y] = ((rowA + y) < AM && (l + x < AN)) ? blockA[x + AN * y] : 0.0f;
                    blockA = blockTempA;
                }
                if (colB + blockSize > BN || l + blockSize > BM) // copy remainder of B into zero-padded block
                {
                    if (blockTempB == null)
                        blockTempB = AllocBlock(blockSize, blockSize);
                    strideB = blockSize;
                    for (int y = 0; y < blockSize; y++)
                    for (int x = 0; x < blockSize; x++)
                        blockTempB[x + blockSize * y] = ((colB + x) < BN && (l + y < BM)) ? blockB[x + BN * y] : 0.0f;
                    blockB = blockTempB;
                }
                MultiplyBlockUnrollHx16(blockA, strideA, blockB, strideB, blockS, strideS);
            }
            if (blockS == blockTempS) // copy back
            {
                // Only the in-bounds portion of the scratch tile is written to S.
                for (int y = 0; y < blockSize; y++)
                for (int x = 0; x < blockSize; x++)
                {
                    if (((rowA + y) < SM) && ((colB + x) < SN))
                        S[(rowA + y) * SN + (colB + x)] = blockTempS[x + blockSize * y];
                }
            }
            // FreeBlock tolerates null (blocks are only allocated on demand).
            FreeBlock(blockTempA);
            FreeBlock(blockTempB);
            FreeBlock(blockTempS);
        }
    }
    // 16x16 tile kernel: Sp += Ap * Bp, with the 16 output columns of each row
    // held in registers (sum0..sumF) across the unrolled k-loop.
    static void MultiplyBlockUnrollHx16(float* Ap, int Astride, float* Bp, int Bstride, float* Sp, int Sstride)
    {
        for (int i = 0; i < blockSize; i++)
        {
            float sum0 = *(Sp + i * Sstride + 0);
            float sum1 = *(Sp + i * Sstride + 1);
            float sum2 = *(Sp + i * Sstride + 2);
            float sum3 = *(Sp + i * Sstride + 3);
            float sum4 = *(Sp + i * Sstride + 4);
            float sum5 = *(Sp + i * Sstride + 5);
            float sum6 = *(Sp + i * Sstride + 6);
            float sum7 = *(Sp + i * Sstride + 7);
            float sum8 = *(Sp + i * Sstride + 8);
            float sum9 = *(Sp + i * Sstride + 9);
            float sumA = *(Sp + i * Sstride + 10);
            float sumB = *(Sp + i * Sstride + 11);
            float sumC = *(Sp + i * Sstride + 12);
            float sumD = *(Sp + i * Sstride + 13);
            float sumE = *(Sp + i * Sstride + 14);
            float sumF = *(Sp + i * Sstride + 15);
            for (int l = 0; l < blockSize; l++)
            {
                float A = *(Ap + i * Astride + l);
                float B0 = *(Bp + l * Bstride + 0);
                float B1 = *(Bp + l * Bstride + 1);
                float B2 = *(Bp + l * Bstride + 2);
                float B3 = *(Bp + l * Bstride + 3);
                float B4 = *(Bp + l * Bstride + 4);
                float B5 = *(Bp + l * Bstride + 5);
                float B6 = *(Bp + l * Bstride + 6);
                float B7 = *(Bp + l * Bstride + 7);
                float B8 = *(Bp + l * Bstride + 8);
                float B9 = *(Bp + l * Bstride + 9);
                float BA = *(Bp + l * Bstride + 10);
                float BB = *(Bp + l * Bstride + 11);
                float BC = *(Bp + l * Bstride + 12);
                float BD = *(Bp + l * Bstride + 13);
                float BE = *(Bp + l * Bstride + 14);
                float BF = *(Bp + l * Bstride + 15);
                sum0 += A * B0;
                sum1 += A * B1;
                sum2 += A * B2;
                sum3 += A * B3;
                sum4 += A * B4;
                sum5 += A * B5;
                sum6 += A * B6;
                sum7 += A * B7;
                sum8 += A * B8;
                sum9 += A * B9;
                sumA += A * BA;
                sumB += A * BB;
                sumC += A * BC;
                sumD += A * BD;
                sumE += A * BE;
                sumF += A * BF;
            }
            *(Sp + i * Sstride + 0 ) = sum0;
            *(Sp + i * Sstride + 1 ) = sum1;
            *(Sp + i * Sstride + 2 ) = sum2;
            *(Sp + i * Sstride + 3 ) = sum3;
            *(Sp + i * Sstride + 4 ) = sum4;
            *(Sp + i * Sstride + 5 ) = sum5;
            *(Sp + i * Sstride + 6 ) = sum6;
            *(Sp + i * Sstride + 7 ) = sum7;
            *(Sp + i * Sstride + 8 ) = sum8;
            *(Sp + i * Sstride + 9 ) = sum9;
            *(Sp + i * Sstride + 10) = sumA;
            *(Sp + i * Sstride + 11) = sumB;
            *(Sp + i * Sstride + 12) = sumC;
            *(Sp + i * Sstride + 13) = sumD;
            *(Sp + i * Sstride + 14) = sumE;
            *(Sp + i * Sstride + 15) = sumF;
        }
    }
}
}
} // namespace Barracuda