79 lines
2.1 KiB
Plaintext
79 lines
2.1 KiB
Plaintext
// Kernel entry points exposed by this compute shader.
#pragma kernel ScanInBucketExclusive
#pragma kernel ScanAddBucketResult

// Number of threads in one thread group. It is also the size of the
// groupshared scratch array below, i.e. the number of elements one group scans.
// Ensure that this equals the 'threadsPerGroup' const in the host script. Must be an odd power of 2.
#define THREADS_PER_GROUP 512

// These must be a multiple of THREADS_PER_GROUP.
// NOTE(review): "these" presumably refers to the element counts of the buffers
// below - confirm against the host script.

// Read-only source values. ScanInBucketExclusive indexes it per thread;
// ScanAddBucketResult indexes it per thread group.
StructuredBuffer<uint> _Input;

// Output values: written by ScanInBucketExclusive, incremented in place by
// ScanAddBucketResult.
RWStructuredBuffer<uint> _Result;

// One entry per thread group: the total of that group's loaded elements,
// written by ScanInBucketExclusive.
RWStructuredBuffer<uint> _BlockSum;

// Threads whose dispatch id is >= count read and write no buffer element.
uint count;

// Per-group scratch space, one slot per thread in the group.
groupshared uint bucket[THREADS_PER_GROUP];
|
|
|
|
// Scan in each bucket.
//
// Every thread group computes an exclusive prefix sum over its own
// THREADS_PER_GROUP-sized slice of _Input using an up-sweep / down-sweep
// (Blelloch-style) scan in groupshared memory:
//   - _Result[DTid] receives the exclusive sum within the group, for DTid < count.
//   - _BlockSum[Gid] receives the total of all elements the group loaded.
//
// DTid : global thread index, used to address _Input / _Result.
// Gid  : thread-group index, used to address _BlockSum.
// GI   : thread index inside the group, used to address 'bucket'.
//
// The sweep indexing relies on THREADS_PER_GROUP being a power of two.
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInBucketExclusive(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID, uint GI : SV_GroupIndex)
{
    // Load one element per thread. Threads past the end of the data load 0,
    // the additive identity, so they do not affect the sums.
    if (DTid < count) {
        bucket[GI] = _Input[DTid];
    } else {
        bucket[GI] = 0;
    }

    // Make every thread's load visible before neighbours are read.
    GroupMemoryBarrierWithGroupSync();

    uint stride;

    // up-sweep: build partial sums in place. After the pass with a given
    // stride, every slot whose (index + 1) is a multiple of that stride holds
    // the sum of the 'stride' elements ending at that slot.
    [unroll]
    for (stride = 2; stride <= THREADS_PER_GROUP; stride <<= 1)
    {
        // Writes from the previous pass must be visible before this pass reads.
        GroupMemoryBarrierWithGroupSync();

        if (((GI + 1) % stride) == 0)
        {
            const uint half_stride = (stride >> 1);
            bucket[GI] += bucket[GI - half_stride];
        }
    }

    // Rendezvous after the up-sweep. Note that the zero stored below is made
    // visible to the other threads by the barrier at the top of the down-sweep
    // loop, not by this one.
    GroupMemoryBarrierWithGroupSync();

    if (GI == THREADS_PER_GROUP - 1)
    {
        // The last slot now holds the total of the whole bucket: publish it
        // as this group's block sum...
        _BlockSum[Gid] = bucket[GI];

        // ...then clear the last element to seed the exclusive down-sweep.
        bucket[GI] = 0;
    }

    // down-sweep: walk back down, swapping each partial sum into its left
    // partner and accumulating into the right one, which turns the partial
    // sums into exclusive prefix sums.
    [unroll]
    for (stride = THREADS_PER_GROUP; stride > 1; stride >>= 1)
    {
        // Writes from the previous pass (or the seed above) must be visible.
        GroupMemoryBarrierWithGroupSync();

        if (((GI + 1) % stride) == 0)
        {
            const uint half_stride = (stride >> 1);
            const uint prev_idx = GI - half_stride;

            // Kept as uint: 'bucket' holds unsigned values, so a signed
            // temporary would force an implicit signed/unsigned round-trip.
            const uint tmp = bucket[prev_idx];
            bucket[prev_idx] = bucket[GI];
            bucket[GI] += tmp;
        }
    }

    // The final pass writes slots owned by neighbouring threads; wait for it
    // before each thread reads its own slot.
    GroupMemoryBarrierWithGroupSync();

    if (DTid < count)
    {
        _Result[DTid] = bucket[GI];
    }
}
|
|
|
|
// Add the bucket scanned result to each bucket to get the final result.
//
// Each in-range thread adds the single value _Input[Gid] (one value per
// thread group) to its own element of _Result, in place.
// NOTE(review): _Input is presumably bound to the scanned per-group sums for
// this pass - confirm against the host script.
//
// DTid : global thread index, used to address _Result.
// Gid  : thread-group index, used to address _Input.
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanAddBucketResult(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID)
{
    // Threads past the end of the data have nothing to update.
    if (DTid >= count)
    {
        return;
    }

    const uint groupOffset = _Input[Gid];
    _Result[DTid] = _Result[DTid] + groupOffset;
}
|