Zgragselus

CWBVH

Jul 11th, 2025
32
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.35 KB | None | 0 0
  1. ///////////////////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // TraversalBvh8.hlsli
  4. //
  5. // Implements ray traversal through a two-level (TLAS/BLAS) compressed wide BVH-8 (CWBVH) acceleration structure.
  6. //
  7. ///////////////////////////////////////////////////////////////////////////////////////////////////
  8.  
  9. #ifndef __TRAVERSAL_BVH8__HLSLI__
  10. #define __TRAVERSAL_BVH8__HLSLI__
  11.  
  12. #include "../Raytracer.hlsli"
  13.  
  // TRACE_RAY_PARAMS expands to the buffer parameter list of TraceRayCompute (BVH-8 variant),
  // so the long signature only has to be written once.
  #define TRACE_RAY_PARAMS \
      RWStructuredBuffer<GeometryNode> Geometries, \
      RWStructuredBuffer<InstanceNode> Instances, \
      RWStructuredBuffer<MemoryNode> ASTreeNodes, \
      RWStructuredBuffer<BVH8Node> ASTreeData, \
      RWStructuredBuffer<MemoryNode> ASIndexNodes, \
      RWStructuredBuffer<uint> ASIndexData, \
      RWStructuredBuffer<MemoryNode> WoopNodes, \
      RWStructuredBuffer<float4> WoopData

  // TRACE_RAY_ARGS expands to the matching argument list for a TraceRayCompute call
  // (BVH-8 variant); it forwards buffers that carry these exact names at the call site.
  #define TRACE_RAY_ARGS \
      Geometries, Instances, ASTreeNodes, ASTreeData, \
      ASIndexNodes, ASIndexData, WoopNodes, WoopData
  33.  
  /// <summary>
  /// Computes the inverse ray octant: one bit per axis (x = 4, y = 2, z = 1), set when the
  /// corresponding direction component is NOT negative. The 3-bit value is replicated into
  /// all four bytes so it can be applied to four packed child slots at once.
  /// </summary>
  /// <param name="rayDirection">Ray direction (only .xyz is inspected)</param>
  /// <returns>Inverse octant (0..7) replicated in every byte of the result</returns>
  uint GetOctant(float4 rayDirection)
  {
      uint octant = 0u;
      octant |= (rayDirection.x < 0.0f) ? 0u : 4u;
      octant |= (rayDirection.y < 0.0f) ? 0u : 2u;
      octant |= (rayDirection.z < 0.0f) ? 0u : 1u;

      // octant <= 7, so multiplying by 0x01010101 copies it into each byte without carries
      return octant * 0x01010101u;
  }
  46.  
  /// <summary>
  /// Returns byte n of a 32-bit value, where byte 0 is the least significant one.
  /// </summary>
  /// <param name="x">Packed 32-bit value</param>
  /// <param name="n">Byte index (0..3)</param>
  /// <returns>The selected byte, in the range 0..255</returns>
  uint ExtractByte(uint x, uint n)
  {
      uint bitOffset = n << 3; // byte index -> bit offset
      return (x >> bitOffset) & 0xFFu;
  }
  57.  
  /// <summary>
  /// Tests a ray against the eight quantized child boxes of one compressed wide BVH-8 node.
  /// </summary>
  /// <param name="origin">Ray origin (only .xyz is used)</param>
  /// <param name="direction">Ray direction (only .xyz is used)</param>
  /// <param name="octantInverse">Inverse ray octant replicated per byte (see GetOctant)</param>
  /// <param name="maxDistance">Far limit of the ray interval</param>
  /// <param name="node0">.xyz = local grid origin; .w = per-axis exponents in bytes 0..2 and the inner-node mask in byte 3</param>
  /// <param name="node1">.x = base child index, .y = base triangle index, .zw = 8 per-child meta bytes</param>
  /// <param name="node2">Quantized X planes: .xy = low planes of the 8 children, .zw = high planes</param>
  /// <param name="node3">Quantized Y planes: .xy = low planes, .zw = high planes</param>
  /// <param name="node4">Quantized Z planes: .xy = low planes, .zw = high planes</param>
  /// <returns>
  /// Hit mask: for every child whose box overlaps the ray interval, that child's bits
  /// (meta >> 5) shifted to its bit index (meta low 5 bits, octant-adjusted for inner nodes).
  /// TraceRayCompute reads bits 24..31 as inner-node hits and bits 0..23 as leaf hits.
  /// </returns>
  uint IntersectNode(float4 origin, float4 direction, uint octantInverse, float maxDistance, float4 node0, float4 node1, float4 node2, float4 node3, float4 node4)
  {
      // Per-axis quantization exponents live in the low three bytes of node0.w
      uint packedExponents = asuint(node0.w);
      uint3 exponents = uint3(ExtractByte(packedExponents, 0),
                              ExtractByte(packedExponents, 1),
                              ExtractByte(packedExponents, 2));

      // asfloat(e << 23) is the float 2^(e-127), i.e. the size of one quantization step;
      // dividing by the direction turns quantized plane values directly into ray distances.
      float3 stepOverDir = asfloat(exponents << 23) / direction.xyz;
      float3 originOffset = (node0.xyz - origin.xyz) / direction.xyz;

      // Raw bit views of the packed rows; component [g] is the low row, [g + 2] the high row
      uint4 metaRows = asuint(node1);
      uint4 planesX = asuint(node2);
      uint4 planesY = asuint(node3);
      uint4 planesZ = asuint(node4);

      // For a negative direction component the ray enters through the high plane
      bool negX = direction.x < 0.0f;
      bool negY = direction.y < 0.0f;
      bool negZ = direction.z < 0.0f;

      uint hitMask = 0;

      // Two groups of four children each (8 children in total)
      [unroll]
      for (int group = 0; group < 2; group++)
      {
          uint metaBytes = metaRows[2 + group];

          // Inner children have bits 3 and 4 of their meta byte set (index 24..31)
          uint innerFlags = (metaBytes & (metaBytes << 1)) & 0x10101010;
          // Expand each inner flag (bit 4 of its byte) to a full 0xff byte
          uint innerBytes = (innerFlags >> 4) * 0xff;
          // Inner children get their slot remapped by the ray octant
          uint shifts = (metaBytes ^ (octantInverse & innerBytes)) & 0x1F1F1F1F;
          uint childBitsPacked = (metaBytes >> 5) & 0x07070707;

          uint3 lowPlanes = uint3(planesX[group], planesY[group], planesZ[group]);
          uint3 highPlanes = uint3(planesX[group + 2], planesY[group + 2], planesZ[group + 2]);

          uint3 nearPacked = uint3(negX ? highPlanes.x : lowPlanes.x,
                                   negY ? highPlanes.y : lowPlanes.y,
                                   negZ ? highPlanes.z : lowPlanes.z);
          uint3 farPacked = uint3(negX ? lowPlanes.x : highPlanes.x,
                                  negY ? lowPlanes.y : highPlanes.y,
                                  negZ ? lowPlanes.z : highPlanes.z);

          [unroll]
          for (int slot = 0; slot < 4; slot++)
          {
              float3 qNear = float3(ExtractByte(nearPacked.x, slot),
                                    ExtractByte(nearPacked.y, slot),
                                    ExtractByte(nearPacked.z, slot));
              float3 qFar = float3(ExtractByte(farPacked.x, slot),
                                   ExtractByte(farPacked.y, slot),
                                   ExtractByte(farPacked.z, slot));

              float3 tNear = mad(qNear, stepOverDir, originOffset);
              float3 tFar = mad(qFar, stepOverDir, originOffset);

              // NOTE(review): the near bound is clamped to 0.1f here, while TraceRayCompute
              // accepts triangle hits from t > 0; boxes lying entirely before t = 0.1 are culled.
              // Confirm this near-plane value is intended.
              float tEnter = max(max(tNear.x, tNear.y), max(tNear.z, 0.1f));
              float tLeave = min(min(tFar.x, tFar.y), min(tFar.z, maxDistance));

              [branch]
              if (tEnter <= tLeave)
              {
                  hitMask |= ExtractByte(childBitsPacked, slot) << ExtractByte(shifts, slot);
              }
          }
      }

      return hitMask;
  }
  169.  
  #ifndef BVH8_MAX_TRAVERSAL_STEPS
  // Maximum number of traversal loop iterations in TraceRayCompute. Traversal stops after this
  // many steps even if unfinished, so hits can be missed when the value is too small.
  // Defaults to the previously hard-coded value; define before including this file to override.
  #define BVH8_MAX_TRAVERSAL_STEPS 1024
  #endif

  /// <summary>
  /// Performs ray traversal through the acceleration structure for a single ray.
  ///
  /// Traverses a two-level compressed wide BVH (BVH-8/CWBVH): a top-level tree over instances
  /// (TLAS, rooted at node 0) and per-geometry bottom-level trees (BLAS). On entering a BLAS
  /// the ray is transformed by the instance's TransformInverse; it is restored to world space
  /// when traversal returns to the TLAS. Triangles are tested with Woop's unit-triangle test.
  /// </summary>
  /// <param name="r">Ray to trace.</param>
  /// <param name="Geometries">Buffer of GeometryNode - holds all definitions for geometries in the scene</param>
  /// <param name="Instances">Buffer of InstanceNode - holds all geometry instance definitions in the scene</param>
  /// <param name="ASTreeNodes">Buffer of memory nodes - each representing single BVH node data definition/offsets</param>
  /// <param name="ASTreeData">Buffer of BVH nodes - BVH node data</param>
  /// <param name="ASIndexNodes">Buffer of memory nodes - each representing single BVH index data definition/offsets</param>
  /// <param name="ASIndexData">Buffer of BVH indices - BVH index data</param>
  /// <param name="WoopNodes">Buffer of memory nodes - each representing definition/offsets into data buffer holding woop triangle data</param>
  /// <param name="WoopData">Buffer of woop triangle data - Woop triangle geometry data</param>
  /// <returns>
  /// 4-component vector, where:
  /// 1st component has packed U/V barycentric coordinates (see PackFloat2/UnpackFloat2)
  /// 2nd component distance along the ray to hit (stays 10000.0f when nothing is hit)
  /// 3rd component primitive ID (unsigned int, bit-cast to float)
  /// 4th component geometry ID (unsigned int, bit-cast to float)
  /// On a miss the barycentrics, primitive ID and geometry ID are all zero.
  /// </returns>
  float4 TraceRayCompute(Ray r, TRACE_RAY_PARAMS)
  {
      // Active ray: world space in the TLAS, instance-local space while inside a BLAS
      float4 o = r.Origin;
      float4 d = r.Direction;
      uint octInv4 = GetOctant(d);

      // Node group: x = base child node index, y = hit bits for inner children (bits 24..31)
      // plus the node's inner-child mask (bits 0..7). The initial group selects node 0 (TLAS root).
      uint2 currentGroup = uint2(0, 0x80000000);
      // Leaf group: x = base index into ASIndexData, y = hit bits for leaf entries (bits 0..23)
      uint2 triangleGroup = uint2(0, 0);

      // NOTE(review): pushes are not bounds-checked against BVH_STACK_SIZE; a tree deeper than
      // the stack allows writes past the end of this array.
      uint2 stack[BVH_STACK_SIZE];
      uint stack_ptr = 0;

      // Stack depth at which the current BLAS was entered; -1 while traversing the TLAS
      int meshbvh_stack_ptr = -1;

      const float tmin = 0.0f;
      float tmax = 10000.0f;
      float bU = 0.0f;
      float bV = 0.0f;
      uint prim_id = 0;
      uint geometryID = 0;

      InstanceNode instance = Instances[0];

      // Iteration-capped traversal loop (see BVH8_MAX_TRAVERSAL_STEPS)
      [loop]
      for (int i = 0; i < BVH8_MAX_TRAVERSAL_STEPS; i++)
      {
          [branch]
          if (currentGroup.y & 0xff000000)
          {
              // Interior group: consume the highest pending child hit bit
              uint childIndexOffset = firstbithigh(currentGroup.y);

              // Undo the octant remapping to get the child's slot, then count the inner
              // children stored before it to find its index in the child node array
              uint slotIndex = (childIndexOffset - 24) ^ (octInv4 & 0xff);
              uint relativeIndex = countbits(currentGroup.y & ~(0xffffffff << slotIndex));
              uint childNodeIndex = currentGroup.x + relativeIndex;

              currentGroup.y &= ~(1u << childIndexOffset);

              // Children still pending in this group are revisited later
              if (currentGroup.y & 0xff000000)
              {
                  stack[stack_ptr] = currentGroup;
                  stack_ptr++;
              }

              float4 node0 = ASTreeData[childNodeIndex].Node0;
              float4 node1 = ASTreeData[childNodeIndex].Node1;
              float4 node2 = ASTreeData[childNodeIndex].Node2;
              float4 node3 = ASTreeData[childNodeIndex].Node3;
              float4 node4 = ASTreeData[childNodeIndex].Node4;

              // Test the ray against all 8 children of the selected node
              uint hitMask = IntersectNode(o, d, octInv4, tmax, node0, node1, node2, node3, node4);

              uint innerMask = (asuint(node0.w) >> 24) & 0xff;
              currentGroup = uint2(asuint(node1.x), (hitMask & 0xff000000) | innerMask);
              triangleGroup = uint2(asuint(node1.y), hitMask & 0x00ffffff);
          }
          else
          {
              // The current group holds only leaf entries
              triangleGroup = currentGroup;
              currentGroup = uint2(0, 0);
          }

          if (triangleGroup.y != 0)
          {
              if (meshbvh_stack_ptr == -1)
              {
                  // TLAS leaf: enter the BLAS of one hit instance; the remaining instances and
                  // the pending interior group are pushed and resumed after this BLAS finishes
                  uint blas_offset = firstbithigh(triangleGroup.y);
                  triangleGroup.y &= ~(1u << blas_offset);

                  uint instance_index = ASIndexData[triangleGroup.x + blas_offset];
                  instance = Instances[instance_index];

                  if (triangleGroup.y != 0)
                  {
                      stack[stack_ptr] = triangleGroup;
                      stack_ptr++;
                  }

                  if (currentGroup.y & 0xff000000)
                  {
                      stack[stack_ptr] = currentGroup;
                      stack_ptr++;
                  }

                  // Everything above this depth belongs to the BLAS being entered
                  meshbvh_stack_ptr = stack_ptr;

                  // Move the ray into the instance's local space
                  o = mul(r.Origin, instance.TransformInverse);
                  d = mul(r.Direction, instance.TransformInverse);
                  octInv4 = GetOctant(d);

                  // BLAS root: byte offset of the geometry's BVH data / 80 bytes per BVH-8 node
                  currentGroup.x = ASTreeNodes[Geometries[instance.GeometryNode].BVHNode + 1].Offset / 80;
                  currentGroup.y = 0x80000000;
              }
              else
              {
                  // BLAS leaf: intersect every hit triangle. The geometry and its Woop buffer
                  // are the same for all triangles here, so load them once.
                  GeometryNode geom = Geometries[instance.GeometryNode];
                  uint woopBase = WoopNodes[geom.WoopBufferNode].Offset / 16;

                  while (triangleGroup.y != 0)
                  {
                      uint triangleIndex = firstbithigh(triangleGroup.y);
                      triangleGroup.y &= ~(1u << triangleIndex);

                      // Each Woop triangle occupies 3 consecutive float4 entries
                      uint tri_idx = ASIndexData[triangleGroup.x + triangleIndex] * 3;

                      float4 woopZ = WoopData[woopBase + tri_idx + 0];
                      float4 woopU = WoopData[woopBase + tri_idx + 1];
                      float4 woopV = WoopData[woopBase + tri_idx + 2];

                      // Woop intersection: distance along the ray in unit-triangle space
                      float o_z = woopZ.w - o.x * woopZ.x - o.y * woopZ.y - o.z * woopZ.z;
                      float i_z = 1.0f / (d.x * woopZ.x + d.y * woopZ.y + d.z * woopZ.z);
                      float t = o_z * i_z;

                      if (t > tmin && t < tmax)
                      {
                          float o_x = woopU.w + o.x * woopU.x + o.y * woopU.y + o.z * woopU.z;
                          float d_x = d.x * woopU.x + d.y * woopU.y + d.z * woopU.z;
                          float u = o_x + t * d_x;

                          if (u >= 0.0f && u <= 1.0f)
                          {
                              float o_y = woopV.w + o.x * woopV.x + o.y * woopV.y + o.z * woopV.z;
                              float d_y = d.x * woopV.x + d.y * woopV.y + d.z * woopV.z;
                              float v = o_y + t * d_y;

                              if (v >= 0.0f && u + v <= 1.0f)
                              {
                                  // Closest hit so far; shrinking tmax also culls farther nodes
                                  tmax = t;
                                  bU = u;
                                  bV = v;

                                  geometryID = instance.GeometryNode;
                                  // NOTE(review): this is the Woop-data triangle offset (index * 3),
                                  // not the raw index from ASIndexData; confirm consumers expect it.
                                  prim_id = tri_idx;
                              }
                          }
                      }
                  }
              }
          }

          // No interior children pending: resume from the stack, or finish when it is empty
          if ((currentGroup.y & 0xff000000) == 0)
          {
              if (stack_ptr == 0)
              {
                  break;
              }

              // Leaving the BLAS: the next entry belongs to the TLAS, so restore the world-space ray
              if ((int)stack_ptr == meshbvh_stack_ptr)
              {
                  meshbvh_stack_ptr = -1;

                  o = r.Origin;
                  d = r.Direction;
                  octInv4 = GetOctant(d);
              }

              stack_ptr--;
              currentGroup = stack[stack_ptr];
          }
      }

      return float4(PackFloat2(bU, bV), tmax, asfloat(prim_id), asfloat(geometryID));
  }
  384.  
  385. #endif
Add Comment
Please, Sign In to add comment