Advertisement
Zgragselus

CWBVH 8-wide traversal

May 24th, 2024
610
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 15.35 KB | None | 0 0
  1. ///////////////////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // TraversalBvh8.hlsli
  4. //
  5. // Implements ray traversal through multi-level BVH-8 (CWBVH) acceleration structure.
  6. //
  7. ///////////////////////////////////////////////////////////////////////////////////////////////////
  8.  
  9. #ifndef __TRAVERSAL_BVH8__HLSLI__
  10. #define __TRAVERSAL_BVH8__HLSLI__
  11.  
  12. #include "../Raytracer.hlsli"
  13.  
  14. // Definition to compact funciton parameters into TraceRayCompute funciton (BVH-8 variant)
  15. #define TRACE_RAY_PARAMS RWStructuredBuffer<GeometryNode> Geometries,\
  16.                          RWStructuredBuffer<InstanceNode> Instances,\
  17.                          RWStructuredBuffer<MemoryNode> ASTreeNodes,\
  18.                          RWStructuredBuffer<BVH8Node> ASTreeData,\
  19.                          RWStructuredBuffer<MemoryNode> ASIndexNodes,\
  20.                          RWStructuredBuffer<uint> ASIndexData,\
  21.                          RWStructuredBuffer<MemoryNode> WoopNodes,\
  22.                          RWStructuredBuffer<float4> WoopData
  23.    
  24. // Definition to compact argument passing into TraceRayCompute function (BVH-8 variant)
  25. #define TRACE_RAY_ARGS  Geometries,\
  26.                         Instances,\
  27.                         ASTreeNodes,\
  28.                         ASTreeData,\
  29.                         ASIndexNodes,\
  30.                         ASIndexData,\
  31.                         WoopNodes,\
  32.                         WoopData
  33.  
  34. /// <summary>
  35. /// Get octant of ray direction
  36. /// </summary>
  37. /// <param name="rayDirection">Ray direction</param>
  38. /// <returns>Octant index encoded in 3 bits</returns>
  39. uint GetOctant(float4 rayDirection)
  40. {
  41.     // Get inverse of ray octant, encoded in 3 bits
  42.     return (rayDirection.x < 0.0f ? 0 : 0x04040404) |
  43.         (rayDirection.y < 0.0f ? 0 : 0x02020202) |
  44.         (rayDirection.z < 0.0f ? 0 : 0x01010101);
  45. }
  46.  
  47. /// <summary>
  48. /// Extract n-th byte from x
  49. /// </summary>
  50. /// <param name="x">Input value</param>
  51. /// <param name="n">Byte index</param>
  52. /// <returns>N-th byte value of input value</returns>
  53. uint ExtractByte(uint x, uint n)
  54. {
  55.     return (x >> (n * 8)) & 0xFF;
  56. }
  57.  
  58. /// <summary>
  59. /// Intersect ray with BVH-8 in compact wide storage
  60. /// </summary>
  61. /// <param name="origin">Ray origin</param>
  62. /// <param name="direction">Ray direction</param>
  63. /// <param name="octantInverse">Inverse of ray octant</param>
  64. /// <param name="maxDistance">Maximum distance to intersect</param>
  65. /// <param name="node0">Holds origin point on local grid in first 12 bytes, exponents for axes in 3 bytes, mask in last byte (determining leaf/interior node)</param>
  66. /// <param name="node1">Holds base child index (4-bytes), base triangle index (4-bytes), meta information (8-bytes)</param>
  67. /// <param name="node2">Holds quantized AABBs - Min X (8-bytes), Max X (8-bytes)</param>
  68. /// <param name="node3">Holds quantized AABBs - Min Y (8-bytes), Max Y (8-bytes)</param>
  69. /// <param name="node4">Holds quantized AABBs - Min Z (8-bytes), Max Z (8-bytes)</param>
  70. /// <returns>Hit mask</returns>
  71. uint IntersectNode(float4 origin, float4 direction, uint octantInverse, float maxDistance, float4 node0, float4 node1, float4 node2, float4 node3, float4 node4)
  72. {
  73.     // Get base local point for children
  74.     float3 p = node0.xyz;
  75.    
  76.     // Get exponents for axes
  77.     uint emask = asuint(node0.w);
  78.     uint eX = ExtractByte(emask, 0);
  79.     uint eY = ExtractByte(emask, 1);
  80.     uint eZ = ExtractByte(emask, 2);
  81.    
  82.     // Get adjusted direction by axes for intersection
  83.     float3 adjDirection = float3(
  84.         asfloat(eX << 23) / direction.x,
  85.         asfloat(eY << 23) / direction.y,
  86.         asfloat(eZ << 23) / direction.z
  87.     );
  88.    
  89.     // Get adjusted origin for intersection
  90.     float3 adjOrigin = (p - origin.xyz) / direction.xyz;
  91.    
  92.     // Resulting hitmask
  93.     uint hitMask = 0;
  94.    
  95.     // Loop through data
  96.     [unroll]
  97.     for (int i = 0; i < 2; i++)
  98.     {
  99.         // Meta infromation
  100.         uint meta4 = asuint(i == 0 ? node1.z : node1.w);
  101.        
  102.         // Extract bit indices and child bits
  103.         uint isInner4 = (meta4 & (meta4 << 1)) & 0x10101010;
  104.         uint innerMask4 = (((isInner4 << 3) >> 7) & 0x01010101) * 0xff;
  105.         uint bitIndex4 = (meta4 ^ (octantInverse & innerMask4)) & 0x1F1F1F1F;
  106.         uint childBits4 = (meta4 >> 5) & 0x07070707;
  107.        
  108.         // Extract quantized min/max of AABBs
  109.         uint qLoX = asuint(i == 0 ? node2.x : node2.y);
  110.         uint qHiX = asuint(i == 0 ? node2.z : node2.w);
  111.        
  112.         uint qLoY = asuint(i == 0 ? node3.x : node3.y);
  113.         uint qHiY = asuint(i == 0 ? node3.z : node3.w);
  114.        
  115.         uint qLoZ = asuint(i == 0 ? node4.x : node4.y);
  116.         uint qHiZ = asuint(i == 0 ? node4.z : node4.w);
  117.        
  118.         // Get per-axis min/max per direction of ray
  119.         uint xMin = direction.x < 0.0f ? qHiX : qLoX;
  120.         uint xMax = direction.x < 0.0f ? qLoX : qHiX;
  121.        
  122.         uint yMin = direction.y < 0.0f ? qHiY : qLoY;
  123.         uint yMax = direction.y < 0.0f ? qLoY : qHiY;
  124.        
  125.         uint zMin = direction.z < 0.0f ? qHiZ : qLoZ;
  126.         uint zMax = direction.z < 0.0f ? qLoZ : qHiZ;
  127.        
  128.         // Loop through all 4 AABBs in current iteration (2-iters = 8 AABBs in total)
  129.         [unroll]
  130.         for (int j = 0; j < 4; j++)
  131.         {
  132.             // Get quantized min value per axis for given AABB
  133.             float3 tmin3 = float3(
  134.                 float(ExtractByte(xMin, j)),
  135.                 float(ExtractByte(yMin, j)),
  136.                 float(ExtractByte(zMin, j)));
  137.  
  138.             // Get quantized max value per axis for given AABB
  139.             float3 tmax3 = float3(
  140.                 float(ExtractByte(xMax, j)),
  141.                 float(ExtractByte(yMax, j)),
  142.                 float(ExtractByte(zMax, j)));
  143.            
  144.             // Use adjusted origin and direction to calculate min/max values
  145.             tmin3 = mad(tmin3, adjDirection, adjOrigin);
  146.             tmax3 = mad(tmax3, adjDirection, adjOrigin);
  147.            
  148.             // Calculate entry and exist distances along ray
  149.             float tmin = max(max(tmin3.x, tmin3.y), max(tmin3.z, 0.1f));
  150.             float tmax = min(min(tmax3.x, tmax3.y), min(tmax3.z, maxDistance));
  151.            
  152.             // Check whether intersection happens
  153.             bool intersection = tmin <= tmax;
  154.            
  155.             // In case of intersection, store in hitmask
  156.             [branch]
  157.             if (intersection)
  158.             {
  159.                 uint childBits = ExtractByte(childBits4, j);
  160.                 uint bitIndex = ExtractByte(bitIndex4, j);
  161.                
  162.                 hitMask |= childBits << bitIndex;
  163.             }
  164.         }
  165.     }
  166.  
  167.     return hitMask;
  168. }
  169.  
  170. /// <summary>
  171. /// Performs ray traversal through acceleration structure for single ray.
  172. ///
  173. /// This function performs traversal through compressed wide Bounding Volume Hierarchy
  174. /// (BVH-8/CWBVH). Result of this funciton is represented by barycentric coordinates, primitive ID,
  175. /// geometry ID and distance along the ray to hitpoint.
  176. /// </summary>
  177. /// <param name="r">Ray to trace.</param>
  178. /// <param name="Geometries">Buffer of GeometryNode - holds all definition for geometries in the scene</param>
  179. /// <param name="Instances">Buffer of InstanceNode - holds all geometry instances definitions in the scene</param>
  180. /// <param name="ASTreeNodes">Buffer of memory nodes - each representing single BVH node data definition/offsets</param>
  181. /// <param name="ASTreeData">Buffer of BVH nodes - BVH node data</param>
  182. /// <param name="ASIndexNodes">Buffer of memory nodes - each representing single BVH index data definition/offsets</param>
  183. /// <param name="ASIndexData">Buffer of BVH indices - BVH index data</param>
  184. /// <param name="WoopNodes">Buffer of memory nodes - each representing definition/offsets into data buffer holding woop triangle data</param>
  185. /// <param name="WoopData">Buffer of woop triangle data - Woop triangle geometry data</param>
  186. /// <returns>
  187. /// 4-component vector, where:
  188. ///   1st component has packed U/V barycentric coordinates (see PackFloat2/UnpackFloat2)
  189. ///   2nd component distance along the ray to hit
  190. ///   3rd component primitive ID (unsigned int)
  191. ///   4th component geometry ID (unsigned int)
  192. /// </returns>
  193. float4 TraceRayCompute(Ray r, TRACE_RAY_PARAMS)
  194. {
  195.     float4 o = r.Origin;
  196.     float4 d = r.Direction;
  197.     float4 inv = r.Inverse;
  198.     uint octInv4 = GetOctant(d);
  199.    
  200.     uint2 currentGroup = uint2(0, 0x80000000);
  201.     uint2 triangleGroup = uint2(0, 0);
  202.        
  203.     uint2 stack[BVH_STACK_SIZE];
  204.     uint stack_ptr = 0;
  205.  
  206.     int meshbvh_stack_ptr = -1;
  207.    
  208.     float tmin = 0.0f;
  209.     float tmax = 10000.0f;
  210.     float bU = 0.0f;
  211.     float bV = 0.0f;
  212.     uint prim_id = 0;
  213.     uint geometryID = 0;
  214.     bool hit = false;
  215.    
  216.     InstanceNode instance = Instances[0];
  217.    
  218.     // Traversal (use for for testing)
  219.     [loop]
  220.     for (int i = 0; i < 1024; i++)
  221.     {
  222.         // Test whether we're in interior node
  223.         [branch]
  224.         if (currentGroup.y & 0xff000000)
  225.         {
  226.             // Get next child index (consume bit)
  227.             uint childIndexOffset = firstbithigh(currentGroup.y);
  228.            
  229.             uint slotIndex = (childIndexOffset - 24) ^ (octInv4 & 0xff);
  230.             uint relativeIndex = countbits(currentGroup.y & ~(0xffffffff << slotIndex));
  231.             uint childNodeIndex = currentGroup.x + relativeIndex;
  232.            
  233.             currentGroup.y &= ~(1 << childIndexOffset);
  234.            
  235.             if (currentGroup.y & 0xff000000)
  236.             {
  237.                 stack[stack_ptr] = currentGroup;
  238.                 stack_ptr++;
  239.             }
  240.            
  241.             // Perform intersection test against all 8 children
  242.             uint hitMask = IntersectNode(o,
  243.                 d,
  244.                 octInv4,
  245.                 tmax,
  246.                 ASTreeData[childNodeIndex].Node0,
  247.                 ASTreeData[childNodeIndex].Node1,
  248.                 ASTreeData[childNodeIndex].Node2,
  249.                 ASTreeData[childNodeIndex].Node3,
  250.                 ASTreeData[childNodeIndex].Node4);
  251.            
  252.             // Update masks from hit results
  253.             currentGroup.y = (hitMask & 0xff000000) | ((asuint(ASTreeData[childNodeIndex].Node0.w) >> 24) & 0xff);
  254.             triangleGroup.y = (hitMask & 0x00ffffff);
  255.            
  256.             currentGroup.x = asuint(ASTreeData[childNodeIndex].Node1.x);
  257.             triangleGroup.x = asuint(ASTreeData[childNodeIndex].Node1.y);
  258.         }
  259.         else
  260.         {
  261.             // Leaf node - current node group holds triangle group
  262.             triangleGroup = currentGroup;
  263.             currentGroup = uint2(0, 0);
  264.         }
  265.        
  266.         // We are in leaf node
  267.         if (triangleGroup.y != 0)
  268.         {
  269.             // We're searching top-level BVH (TLAS), enter bottom-level BVH (BLAS)
  270.             if (meshbvh_stack_ptr == -1)
  271.             {
  272.                 uint blas_offset = firstbithigh(triangleGroup.y);
  273.                 triangleGroup.y &= ~(1 << blas_offset);
  274.                 uint index_offset = ASIndexNodes[triangleGroup.x + blas_offset].Offset / 4;
  275.                 uint instance_index = ASIndexData[triangleGroup.x + blas_offset];
  276.                 instance = Instances[instance_index];
  277.                                
  278.                 if (triangleGroup.y != 0)
  279.                 {
  280.                     stack[stack_ptr] = triangleGroup;
  281.                     stack_ptr++;
  282.                 }
  283.                
  284.                 if (currentGroup.y & 0xff000000)
  285.                 {
  286.                     stack[stack_ptr] = currentGroup;
  287.                     stack_ptr++;
  288.                 }
  289.                
  290.                 meshbvh_stack_ptr = stack_ptr;
  291.            
  292.                 o = mul(r.Origin, instance.TransformInverse);
  293.                 d = mul(r.Direction, instance.TransformInverse);
  294.                 inv = rcp(d);
  295.                 octInv4 = GetOctant(d);
  296.                
  297.                 currentGroup.x = ASTreeNodes[Geometries[instance.GeometryNode].BVHNode + 1].Offset / 80;
  298.                 currentGroup.y = 0x80000000;
  299.             }
  300.             // We're already in bottom-level BVH (BLAS)
  301.             else
  302.             {
  303.                 while (triangleGroup.y != 0)
  304.                 {
  305.                     // Obtain next triangle from triangle group in BLAS node record
  306.                     uint triangleIndex = firstbithigh(triangleGroup.y);
  307.                     triangleGroup.y &= ~(1 << triangleIndex);
  308.            
  309.                     GeometryNode geom = Geometries[instance.GeometryNode];
  310.                     MemoryNode wbo = WoopNodes[geom.WoopBufferNode];
  311.                
  312.                     uint index_offset = ASIndexNodes[triangleGroup.x + triangleIndex].Offset / 4;
  313.              
  314.                     // Don't trash cache by reading index through it
  315.                     uint tri_idx = ASIndexData[triangleGroup.x + triangleIndex] * 3;
  316.                    
  317.                     // Fetch data for Woop's triangle
  318.                     float4 r = WoopData[wbo.Offset / 16 + tri_idx + 0];
  319.                     float4 p = WoopData[wbo.Offset / 16 + tri_idx + 1];
  320.                     float4 q = WoopData[wbo.Offset / 16 + tri_idx + 2];
  321.                    
  322.                     // Perform intersection
  323.                     float o_z = r.w - o.x * r.x - o.y * r.y - o.z * r.z;
  324.                     float i_z = 1.0f / (d.x * r.x + d.y * r.y + d.z * r.z);
  325.                     float t = o_z * i_z;
  326.                    
  327.                     if (t > tmin && t < tmax)
  328.                     {
  329.                         float o_x = p.w + o.x * p.x + o.y * p.y + o.z * p.z;
  330.                         float d_x = d.x * p.x + d.y * p.y + d.z * p.z;
  331.                         float u = o_x + t * d_x;
  332.  
  333.                         if (u >= 0.0f && u <= 1.0f)
  334.                         {
  335.                             float o_y = q.w + o.x * q.x + o.y * q.y + o.z * q.z;
  336.                             float d_y = d.x * q.x + d.y * q.y + d.z * q.z;
  337.                             float v = o_y + t * d_y;
  338.                            
  339.                             if (v >= 0.0f && u + v <= 1.0f)
  340.                             {
  341.                                 tmax = t;
  342.                                 bU = u;
  343.                                 bV = v;
  344.                                 hit = true;
  345.  
  346.                                 geometryID = instance.GeometryNode;
  347.                                 prim_id = tri_idx;
  348.                             }
  349.                         }
  350.                     }
  351.                 }
  352.             }
  353.         }
  354.        
  355.         // Pop stack if any item still in it, end traversal otherwise
  356.         if ((currentGroup.y & 0xff000000) == 0)
  357.         {
  358.             // Entrypoint has been reached, terminate traversal
  359.             if (stack_ptr == 0)
  360.             {
  361.                 break;
  362.             }
  363.            
  364.             // If we're in BLAS and we're on entrypoint, then reset the ray as the traversal will
  365.             // continue in TLAS
  366.             if (stack_ptr == meshbvh_stack_ptr)
  367.             {
  368.                 meshbvh_stack_ptr = -1;
  369.                
  370.                 o = r.Origin;
  371.                 d = r.Direction;
  372.                 inv = r.Inverse;
  373.                 octInv4 = GetOctant(d);
  374.             }
  375.            
  376.             // Pop from stack
  377.             stack_ptr--;
  378.             currentGroup = stack[stack_ptr];
  379.         }
  380.     }
  381.    
  382.     return float4(PackFloat2(bU, bV), tmax, asfloat(prim_id), asfloat(geometryID));
  383. }
  384.  
  385. #endif
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement