_xiaofang/xiaofang/Assets/Obi/Resources/Compute/Scan.compute

#pragma kernel ScanInBucketExclusive
#pragma kernel ScanAddBucketResult
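
// Two-pass exclusive prefix sum (Blelloch scan):
// 1) ScanInBucketExclusive scans each THREADS_PER_GROUP-sized bucket in groupshared
//    memory and writes every bucket's total to _BlockSum.
// 2) The host is then expected to exclusively scan _BlockSum and bind the result as
//    _Input when dispatching ScanAddBucketResult, which adds each bucket's offset
//    to produce the final scan.
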
#define THREADS_PER_GROUP 512 // Ensure that this equals the 'threadsPerGroup' const in the host script. Must be a power of 2.
// Buffer lengths must be a multiple of THREADS_PER_GROUP.
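// _Input:    values to scan (first pass), or the scanned bucket totals (second pass).
// _Result:   per-bucket exclusive scan results; the final output after the second pass.
// _BlockSum: one entry per bucket, receives each bucket's total sum.
// count:     number of valid elements in _Input/_Result.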
StructuredBuffer<uint> _Input;
RWStructuredBuffer<uint> _Result;
RWStructuredBuffer<uint> _BlockSum;
uint count;
groupshared uint bucket[THREADS_PER_GROUP];
// Scan in each bucket.
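// Example for a single bucket: input [3, 1, 7, 0] -> exclusive scan [0, 3, 4, 11],
// and the bucket total (3 + 1 + 7 + 0 = 11) is written to _BlockSum[Gid].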
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInBucketExclusive(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID, uint GI : SV_GroupIndex)
{
if (DTid < count) {
bucket[GI] = _Input[DTid];
} else {
bucket[GI] = 0;
}
GroupMemoryBarrierWithGroupSync();
uint stride;
// up-sweep
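// Reduce phase: at each stride, element GI accumulates the partial sum half a
// stride to its left, building a binary reduction tree whose last element ends
// up holding the bucket total.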
[unroll]
for (stride = 2; stride <= THREADS_PER_GROUP; stride <<= 1)
{
GroupMemoryBarrierWithGroupSync();
if (((GI + 1) % stride) == 0)
{
const uint half_stride = (stride >> 1);
bucket[GI] += bucket[GI - half_stride];
}
}
// Make sure the up-sweep is complete before the last element is read,
// saved to _BlockSum and cleared below.
GroupMemoryBarrierWithGroupSync();
if (GI == THREADS_PER_GROUP - 1)
{
// Store this bucket's total for the block-sum pass, then clear the last
// element so the down-sweep produces an exclusive (rather than inclusive) scan.
_BlockSum[Gid] = bucket[GI];
bucket[GI] = 0;
}
// down-sweep
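// Traverse the tree back down: each visited element passes its current value to
// the element half a stride to its left and adds that element's old value to
// itself, turning the reduction tree into an exclusive scan.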
[unroll]
for (stride = THREADS_PER_GROUP; stride > 1; stride >>= 1)
{
GroupMemoryBarrierWithGroupSync();
if (((GI + 1) % stride) == 0)
{
const uint half_stride = (stride >> 1);
const uint prev_idx = GI - half_stride;
const uint tmp = bucket[prev_idx];
bucket[prev_idx] = bucket[GI];
bucket[GI] += tmp;
}
}
GroupMemoryBarrierWithGroupSync();
if (DTid < count)
_Result[DTid] = bucket[GI];
}

// Add the scanned bucket offsets to each bucket to get the final result.
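// _Input is expected to be bound to the exclusive scan of _BlockSum here, so
// _Input[Gid] holds the offset of bucket Gid within the full sequence.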
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanAddBucketResult(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID)
{
if (DTid < count)
_Result[DTid] += _Input[Gid];
}