Guest User

Untitled

a guest
Nov 4th, 2025
22
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.98 KB | None | 0 0
  1. import torch
  2. from functools import partial
  3. import math
  4. import operator
  5. from typing import Callable, Optional
  6.  
  7. import cutlass
  8. import cutlass.cute as cute
  9. from cutlass import const_expr, Float16, Float32, Int32, Int64, Boolean
  10. from cutlass.cute.nvgpu import cpasync, warp, warpgroup
  11. from cutlass.cutlass_dsl import T, dsl_user_op
  12. from cutlass._mlir.dialects import llvm, nvvm, vector
  13. from cutlass.utils import LayoutEnum, StaticPersistentTileScheduler
  14. import cutlass.utils.hopper_helpers as sm90_utils
  15. from cutlass.cute.runtime import from_dlpack
  16. import cuda.bindings.driver as cuda
  17. from cutlass.cute.runtime import make_ptr
  18.  
  19. @cute.jit
  20. def semaphore_acquire(
  21. lock: cute.Pointer,
  22. state: Int32,
  23. ):
  24. lock_ptr = lock.toint().ir_value()
  25. tx = cute.arch.thread_idx()[0]
  26. if tx==0 or tx==1:
  27. cute.printf("before tx={} state = {}", tx, state)
  28. if tx==0:
  29. llvm.inline_asm(
  30. None, # return
  31. [state.ir_value(), lock_ptr], # input
  32. f"ld.global.acquire.gpu.b32 $0, [$1];",
  33. f"r,l",
  34. has_side_effects=True,
  35. is_align_stack=False,
  36. asm_dialect=llvm.AsmDialect.AD_ATT,
  37. )
  38. cute.arch.sync_threads()
  39. if tx==0 or tx==1:
  40. cute.printf("tx={} state = {}", tx, state)
  41. return state
  42.  
  43. @cute.jit
  44. def semaphore_release(
  45. lock: cute.Pointer,
  46. state: Int32,
  47. ):
  48. lock_ptr = lock.toint().ir_value()
  49. tx = cute.arch.thread_idx()[0]
  50. if tx==0:
  51. llvm.inline_asm(
  52. None, # return
  53. [lock_ptr, state.ir_value()], # input
  54. f"st.global.release.gpu.b32 [$0], $1;",
  55. f"l,r",
  56. has_side_effects=True,
  57. is_align_stack=False,
  58. asm_dialect=llvm.AsmDialect.AD_ATT,
  59. )
  60.  
  61. @cute.jit
  62. def semaphore_wait(
  63. lock: cute.Pointer,
  64. need_state: Int32,
  65. ):
  66. cur_state = semaphore_acquire(lock, need_state)
  67. while cute.arch.vote_all_sync(cur_state != need_state):
  68. cur_state = semaphore_acquire(lock, need_state)
  69. cute.arch.sync_threads()
  70.  
  71.  
  72. @cute.kernel
  73. def demo_kernel(
  74. a_ptr: cute.Pointer
  75. ):
  76. tx = cute.arch.thread_idx()[0]
  77. bx = cute.arch.block_idx()[0]
  78. state = Int32(1) + bx
  79. smem = cutlass.utils.SmemAllocator()
  80. sA = smem.allocate_tensor(Int32, cute.make_layout((32, )))
  81. sA[tx] = tx + bx
  82. cute.arch.sync_threads()
  83. # semaphore_acquire(a_ptr, state)
  84. # if tx==0:
  85. # state += 3
  86. # semaphore_acquire(a_ptr, state)
  87. semaphore_wait(a_ptr, state)
  88. if tx==0:
  89. cute.printf("bx={}", bx)
  90. cute.print_tensor(sA, verbose=True)
  91. semaphore_release(a_ptr, state)
  92.  
  93.  
  94. @cute.jit
  95. def demof(
  96. a_ptr: cute.Pointer,
  97. ):
  98. demo_kernel(a_ptr).launch(
  99. grid=[2, 1, 1],
  100. block=[32, 1, 1],
  101. )
  102.  
  103. torch.ones(3, 4, device="cuda")
  104. a = torch.zeros(1, device="cuda", dtype=torch.int32)
  105. a_ptr = make_ptr(
  106. Int32, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
  107. )
  108. demof(a_ptr)
Advertisement
Add Comment
Please, Sign In to add comment