// BVLE Voxels - Bilateral AO Blur Compute Shader (Phase 6.3)
// Separable bilateral blur: preserves edges using depth + normal comparison.
// Run twice: horizontal (direction=0) then vertical (direction=1).

#include "voxelCommon.hlsli"

// Resource element types are inferred from how main() uses each resource:
// AO and depth are read as scalars, normals through .xyz, and the output is
// written as a scalar. NOTE(review): confirm these match the bound formats.

// Input AO (raw or partially blurred)
Texture2D<float> aoInput : register(t0);

// Depth + normals for edge-stopping
Texture2D<float> depthTexture : register(t1);
Texture2D<float4> normalTexture : register(t2);

// Output blurred AO
RWTexture2D<float> aoOutput : register(u0);

struct BlurPush
{
    uint width;             // output/input extent in pixels
    uint height;
    uint direction;         // 0 = horizontal, anything else = vertical
    uint radius;            // blur kernel radius (e.g. 4 = 9x1 kernel)
    float depthThreshold;   // edge-stopping depth sensitivity: width of the Gaussian falloff on raw depth difference
    float normalThreshold;  // edge-stopping normal sensitivity (dot product); samples at or below it are rejected
    uint pad[6];            // never read here. NOTE(review): D3D cbuffer packing puts each array element on a
                            // 16-byte boundary, so the shader-side size may differ from a tightly packed CPU struct.
};

[[vk::push_constant]] ConstantBuffer<BlurPush> push : register(b999);

// One pass of the separable bilateral AO blur.
//
// Each thread filters one pixel along a single axis (chosen by push.direction),
// taking push.radius taps on each side. Every tap is weighted by:
//   - a spatial Gaussian with sigma = (radius + 1) / 2 pixels,
//   - a depth Gaussian exp(-(dz / depthThreshold)^2) on the raw depth difference,
//   - the clamped normal dot product, zeroed when it is not above normalThreshold.
// Sky pixels (depth == 0) are written as fully unoccluded (1.0), and sky taps are
// ignored. The center tap always contributes with weight 1, so the normalization
// denominator is never below 1.
[RootSignature(VOXEL_ROOTSIG)]
[numthreads(8, 8, 1)]
void main(uint3 DTid : SV_DispatchThreadID)
{
    if (DTid.x >= push.width || DTid.y >= push.height)
        return;

    const float centerAO = aoInput[DTid.xy];
    const float centerDepth = depthTexture[DTid.xy];
    // Assumes normals are stored as unit vectors (not [0,1]-encoded) -- TODO confirm.
    const float3 centerN = normalTexture[DTid.xy].xyz;

    // Skip sky pixels (depth cleared to 0; presumably reversed-Z far plane).
    if (centerDepth == 0.0)
    {
        aoOutput[DTid.xy] = 1.0;
        return;
    }

    // Loop invariants. ("axis" rather than "step" to avoid shadowing the step() intrinsic.)
    const int2 axis = (push.direction == 0) ? int2(1, 0) : int2(0, 1);
    const int r = (int)push.radius;
    const int2 extent = int2((int)push.width, (int)push.height);
    const float invSpan = 1.0 / float(r + 1);

    // Guarded 1 / depthThreshold^2. With depthThreshold == 0 the unguarded form
    // evaluates 0/0 = NaN for equal depths and turns the whole pixel into NaN.
    // Clamping keeps the weight finite: in that limit only exactly-equal depths
    // contribute, and any positive threshold behaves as before.
    const float kMinDepthThresholdSq = 1e-20;
    const float invDepthThresholdSq =
        1.0 / max(push.depthThreshold * push.depthThreshold, kMinDepthThresholdSq);

    // The center tap contributes with weight 1.
    float totalWeight = 1.0;
    float totalAO = centerAO;

    for (int i = -r; i <= r; i++)
    {
        if (i == 0)
            continue;

        const int2 coord = int2(DTid.xy) + axis * i;
        if (any(coord < 0) || any(coord >= extent))
            continue;

        // Skip sky taps before fetching the remaining inputs.
        const float sampleDepth = depthTexture[coord];
        if (sampleDepth == 0.0)
            continue;

        const float sampleAO = aoInput[coord];
        const float3 sampleN = normalTexture[coord].xyz;

        // Edge-stopping: depth difference (Gaussian on raw depth).
        const float depthDiff = abs(centerDepth - sampleDepth);
        const float depthWeight = exp(-depthDiff * depthDiff * invDepthThresholdSq);

        // Edge-stopping: normal difference (hard cutoff at normalThreshold).
        const float normalDot = max(0.0, dot(centerN, sampleN));
        const float normalWeight = (normalDot > push.normalThreshold) ? normalDot : 0.0;

        // Spatial weight: exp(-2 * (i / (r + 1))^2), i.e. sigma = (r + 1) / 2 pixels.
        const float dist = float(abs(i)) * invSpan;
        const float spatialWeight = exp(-dist * dist * 2.0);

        const float w = spatialWeight * depthWeight * normalWeight;
        totalWeight += w;
        totalAO += sampleAO * w;
    }

    aoOutput[DTid.xy] = totalAO / totalWeight;
}