bvle-voxels/shaders/voxelAOBlurCS.hlsl

84 lines
2.7 KiB
HLSL
Raw Normal View History

// BVLE Voxels - Bilateral AO Blur Compute Shader (Phase 6.3)
// Separable bilateral blur: preserves edges using depth + normal comparison.
// Run twice: horizontal (direction=0) then vertical (direction=1).
#include "voxelCommon.hlsli"
// Input AO (raw or partially blurred). Sampled at the center texel and at each
// blur tap; presumably in [0,1] with 1 = unoccluded (sky texels are written as
// 1.0 by main) — TODO confirm against the AO generation pass.
Texture2D<float> aoInput : register(t0);
// Depth + normals for edge-stopping.
// depthTexture: a value of exactly 0.0 is treated as sky (no geometry) and is
// never blended; presumably a reversed-Z depth buffer cleared to 0 — confirm.
Texture2D<float> depthTexture : register(t1);
// normalTexture: only .xyz is read and dotted directly, which assumes signed
// unit normals (not [0,1]-encoded) — NOTE(review): confirm the G-buffer format.
Texture2D<float4> normalTexture : register(t2);
// Output blurred AO, one texel written per in-bounds dispatch thread.
RWTexture2D<float> aoOutput : register(u0);
// Push constants for one blur pass. Field order and total size must match the
// CPU-side struct that fills the push-constant / root-constant range
// (presumably 12 dwords = 48 bytes, given 6 fields + 6 padding words).
struct BlurPush {
    uint width;             // dispatch / AO target width in texels (bounds check)
    uint height;            // dispatch / AO target height in texels (bounds check)
    uint direction;         // 0 = horizontal, any other value = vertical
    uint radius;            // blur kernel radius (e.g. 4 = 9x1 kernel)
    float depthThreshold;   // edge-stopping depth sensitivity (Gaussian width on raw depth difference)
    float normalThreshold;  // edge-stopping normal sensitivity (dot product)
    // Padding to 48 bytes, declared as scalars instead of `uint pad[6]`.
    // Under HLSL constant-buffer packing every array element begins on a new
    // 16-byte register, so the array form would start at offset 32 and span
    // 84 bytes, making the HLSL struct 116 bytes on D3D12 and mismatching a
    // 12-dword root-constant range. Scalars pack tightly under both D3D and
    // SPIR-V layout rules; offsets of all fields above are unchanged.
    uint pad0;
    uint pad1;
    uint pad2;
    uint pad3;
    uint pad4;
    uint pad5;
};
[[vk::push_constant]] ConstantBuffer<BlurPush> push : register(b999);
// One separable pass of a bilateral AO blur.
// Each thread filters texel DTid along one axis (push.direction). It combines up
// to push.radius taps on each side using a Gaussian spatial falloff multiplied
// by depth and normal edge-stopping weights. The center texel always
// contributes with weight 1, so the normalization never divides by zero.
// Sky texels (depth == 0) are written as 1.0 and are never used as taps.
[RootSignature(VOXEL_ROOTSIG)]
[numthreads(8, 8, 1)]
void main(uint3 DTid : SV_DispatchThreadID) {
    if (DTid.x >= push.width || DTid.y >= push.height) return;

    const float centerDepth = depthTexture[DTid.xy];

    // Sky pixels carry no occlusion: write fully unoccluded and stop early
    // (before touching the AO / normal textures).
    if (centerDepth == 0.0) {
        aoOutput[DTid.xy] = 1.0;
        return;
    }

    const float centerAO = aoInput[DTid.xy];
    const float3 centerN = normalTexture[DTid.xy].xyz;

    // ---- Loop-invariant terms ----
    // Clamp the squared depth threshold away from zero. With depthThreshold == 0
    // the expression -d*d/(t*t) evaluated to 0/0 = NaN for equal depths, and the
    // NaN poisoned the whole weighted sum. 1e-30 is far below any meaningful
    // threshold, so non-zero settings behave as before.
    const float invDepthSigmaSq =
        1.0 / max(push.depthThreshold * push.depthThreshold, 1e-30);

    const int r = (int)push.radius;
    const float invKernelExtent = 1.0 / float(r + 1);

    // Unit offset along the blur axis. Not named 'step', so it does not shadow
    // the HLSL step() intrinsic.
    const int2 axisStep = (push.direction == 0) ? int2(1, 0) : int2(0, 1);

    float totalWeight = 1.0;      // center tap weight
    float totalAO = centerAO;

    for (int i = -r; i <= r; i++) {
        if (i == 0) continue;     // center already accumulated above

        const int2 coord = int2(DTid.xy) + axisStep * i;
        if (coord.x < 0 || coord.x >= (int)push.width ||
            coord.y < 0 || coord.y >= (int)push.height)
            continue;

        const float sampleDepth = depthTexture[coord];
        if (sampleDepth == 0.0) continue;   // never blend sky into geometry

        const float sampleAO = aoInput[coord];
        const float3 sampleN = normalTexture[coord].xyz;

        // Depth edge-stopping: Gaussian on the raw depth-buffer difference.
        // NOTE(review): raw depth is non-linear (presumably reversed-Z, since
        // sky == 0), so the effective threshold in view-space units grows with
        // distance. Consider linearized or relative depth if far edges bleed.
        const float depthDiff = centerDepth - sampleDepth;
        const float depthWeight = exp(-depthDiff * depthDiff * invDepthSigmaSq);

        // Normal edge-stopping: reject taps whose cosine similarity does not
        // exceed the threshold; otherwise weight by that cosine.
        const float normalDot = max(0.0, dot(centerN, sampleN));
        const float normalWeight = (normalDot > push.normalThreshold) ? normalDot : 0.0;

        // Spatial Gaussian: exp(-2 * (|i| / (r+1))^2), i.e. sigma = (r+1)/2 taps.
        const float dist = float(abs(i)) * invKernelExtent;
        const float spatialWeight = exp(-2.0 * dist * dist);

        const float w = spatialWeight * depthWeight * normalWeight;
        totalWeight += w;
        totalAO += sampleAO * w;
    }

    // totalWeight >= 1 (center tap), so this division is always well-defined.
    aoOutput[DTid.xy] = totalAO / totalWeight;
}