import torch
from functools import partial

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

from triton.testing import do_bench

@cute.jit
def vector_mm(a: cute.Tensor, b: cute.Tensor, tile_K_idx):
    # Broadcast a (M, K) and b (K, N) to a common (M, K, N) shape by appending /
    # prepending a stride-0 mode, multiply elementwise, and reduce over K.
    # (The sync_threads calls throughout are debugging aids from the repro.)
    cute.arch.sync_threads()
    broadcast_a = cute.make_tensor(
        a.iterator, cute.append(a.layout, cute.make_layout(b.shape[1], stride=0))
    )
    cute.arch.sync_threads()
    broadcast_b = cute.make_tensor(
        b.iterator, cute.prepend(b.layout, cute.make_layout(a.shape[0], stride=0))
    )
    cute.arch.sync_threads()

    tidx, tidy, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    # Print once: from thread (0, 0) of block (0, 0) on the first K tile.
    if max(tidx, tidy, bidx, bidy, tile_K_idx) == 0:
        cute.printf("broadcast_a: ")
        cute.printf(broadcast_a)
        cute.printf("broadcast_b: ")
        cute.printf(broadcast_b)

    cute.arch.sync_threads()
    broadcast_a = broadcast_a.load()
    cute.arch.sync_threads()
    broadcast_b = broadcast_b.load()
    cute.arch.sync_threads()

    mul_res = broadcast_a.to(cute.Float32) * broadcast_b.to(cute.Float32)
    cute.arch.sync_threads()
    mul_res_frag = cute.make_fragment(mul_res.shape, dtype=cute.Float32)
    cute.arch.sync_threads()
    mul_res_frag.store(mul_res)
    cute.arch.sync_threads()
    if max(tidx, tidy, bidx, bidy, tile_K_idx) == 0:
        cute.printf("mul_res: ")
        cute.printf(mul_res_frag)
    # Sum over the K mode (mode 1), yielding an (M, N) result.
    return mul_res.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1, None))

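# A host-side reference for what vector_mm computes (a minimal sketch; the
# helper name and shapes here are illustrative, not part of the kernel API).
# Broadcasting a to (M, K, N) through a trailing stride-0 mode and b through a
# leading one, multiplying, and summing over K is just an fp32 matmul:
def vector_mm_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    prod = a.float().unsqueeze(-1) * b.float().unsqueeze(0)  # (M, K, N), stride-0 broadcasts
    return prod.sum(dim=1)                                   # reduce over K -> (M, N)
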
@cute.jit
def loop_mm(a: cute.Tensor, b: cute.Tensor, tile_K_idx):
    # Naive fully unrolled tile matmul: tmp_val[m, n] = sum_k a[m, k] * b[k, n].
    tmp_val = cute.make_fragment((a.shape[0], b.shape[1]), dtype=cute.Float32)
    for m in cutlass.range_constexpr(a.shape[0]):
        for n in cutlass.range_constexpr(b.shape[1]):
            tmp_val[(m, n)] = 0.0
            for k in cutlass.range_constexpr(a.shape[1]):
                tmp_val[(m, n)] += a[(m, k)] * b[(k, n)]
    return tmp_val.load()

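# Host-side sanity check for the unrolled loops above (a sketch; plain Python
# stand-in for the constexpr-unrolled device code, equivalent to a fp32 matmul):
def loop_mm_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float32)
    for m in range(a.shape[0]):
        for n in range(b.shape[1]):
            for k in range(a.shape[1]):
                out[m, n] += float(a[m, k]) * float(b[k, n])
    return out  # matches a.float() @ b.float()
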
@cute.kernel
def matmul_kernel(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
    blk: cutlass.Constexpr,
):
    tidx, tidy, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    tile_M_idx = bidx
    tile_N_idx = bidy

    # Per-block tiles: gA and gB are (blk, blk, num_k_blocks), gC is (blk, blk).
    gA = cute.local_tile(mA, (blk, blk), coord=(tile_M_idx, None))
    gB = cute.local_tile(mB, (blk, blk), coord=(None, tile_N_idx))
    gC = cute.local_tile(mC, (blk, blk), coord=(tile_M_idx, tile_N_idx))

    # Passing smem layouts (column-major sA, row-major sB):
    # sA_layout = cute.make_layout((blk, blk), stride=(1, blk))
    # sB_layout = cute.make_layout((blk, blk), stride=(blk, 1))
    # Failing smem layouts (row-major sA, column-major sB):
    sA_layout = cute.make_layout((blk, blk), stride=(blk, 1))
    sB_layout = cute.make_layout((blk, blk), stride=(1, blk))

    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(gA.element_type, sA_layout)
    sB = smem.allocate_tensor(gB.element_type, sB_layout)

    num_k_blocks = gA.shape[2]
    tmp_val = cute.make_fragment((1, 1), dtype=cute.Float32)
    tmp_val.fill(0.0)
    for tile_K_idx in range(num_k_blocks):
        cute.arch.sync_threads()
        # Each thread stages one element of the current K tile into shared memory.
        sA[(tidy, tidx)] = gA[(tidy, tidx, tile_K_idx)]
        sB[(tidy, tidx)] = gB[(tidy, tidx, tile_K_idx)]
        cute.arch.sync_threads()

        # Slice row tidx of sA and column tidy of sB. The plain-index slices
        # below are equivalent to the local_tile calls and were dead stores
        # (immediately overwritten), so they are left commented out:
        # local_a = sA[(tidx, None)]
        # local_b = sB[(None, tidy)]
        local_a = cute.local_tile(sA, (1, blk), coord=(tidx, 0))
        local_b = cute.local_tile(sB, (blk, 1), coord=(0, tidy))
        cur = vector_mm(local_a, local_b, tile_K_idx)
        # cur = loop_mm(local_a, local_b, tile_K_idx)  # Passes with either smem layout
        tmp_val.store(tmp_val.load() + cur)

    gC[(tidx, tidy)] = tmp_val[(0, 0)].to(cute.Float16)

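# Host-side sketch of the address math behind the "passing" vs "failing" smem
# layouts above (plain Python, purely illustrative): offset(i, j) = i*s0 + j*s1,
# and each iteration the kernel slices row tidx of sA and column tidy of sB.
def slice_offsets(stride, blk=16):
    s0, s1 = stride
    row = [0 * s0 + j * s1 for j in range(blk)]  # offsets of sA[(0, None)]
    col = [i * s0 + 0 * s1 for i in range(blk)]  # offsets of sB[(None, 0)]
    return row, col

# Failing pair: stride (blk, 1) for sA / (1, blk) for sB makes both slices
# unit-stride, e.g. slice_offsets((16, 1))[0][:4] == [0, 1, 2, 3].
# Passing pair: the swapped strides make both slices step by blk instead, e.g.
# slice_offsets((1, 16))[0][:4] == [0, 16, 32, 48].
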
@cute.jit
def matmul(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    m, k1 = mA.shape
    k2, n = mB.shape
    blk = 16
    # Two (blk, blk) staging buffers; both layouts are compact, so the stride
    # order does not change the byte count.
    smem_size = cute.size_in_bytes(
        mA.element_type, cute.make_layout((blk, blk), stride=(blk, 1))
    ) + cute.size_in_bytes(
        mA.element_type, cute.make_layout((blk, blk), stride=(1, blk))
    )
    kernel = matmul_kernel(mA, mB, mC, blk)
    kernel.launch(
        grid=(cute.ceil_div(mA.shape[0], blk), cute.ceil_div(mB.shape[1], blk), 1),
        block=(blk, blk, 1),
        smem=smem_size,
    )

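# Worked example of the smem budget above (a sketch, assuming fp16 inputs as in
# test() below): each compact (16, 16) buffer holds 16 * 16 elements of 2 bytes,
# so the launch requests 1024 bytes of dynamic shared memory either way.
assert 2 * (16 * 16 * 2) == 1024  # two buffers * (blk * blk elements * sizeof(fp16))
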
def test():
    torch.manual_seed(0)
    N = 16
    M = 16
    K = 16
    a = torch.ones(M, K, device='cuda', dtype=torch.float16)
    b = torch.ones(K, N, device='cuda', dtype=torch.float16)
    # .T makes the (M, N) output a column-major view of the (N, M) buffer.
    out = torch.zeros(N, M, device='cuda', dtype=torch.float16).T
    a_ = from_dlpack(a, assumed_align=16)
    b_ = from_dlpack(b, assumed_align=16)
    out_ = from_dlpack(out, assumed_align=16)
    matmul(a_, b_, out_)

    ref = torch.mm(a, b)
    torch.testing.assert_close(out, ref)
    print("Test passed")

test()
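
# do_bench is imported above but never used; a minimal benchmarking sketch,
# assuming sizes that are multiples of blk = 16 (the kernel has no bounds
# checks) and keeping in mind that the cute.printf debugging output in
# vector_mm fires once per launch:
def bench(M=256, N=256, K=256):
    a = torch.randn(M, K, device='cuda', dtype=torch.float16)
    b = torch.randn(K, N, device='cuda', dtype=torch.float16)
    out = torch.zeros(M, N, device='cuda', dtype=torch.float16)
    a_ = from_dlpack(a, assumed_align=16)
    b_ = from_dlpack(b, assumed_align=16)
    out_ = from_dlpack(out, assumed_align=16)
    ms = do_bench(lambda: matmul(a_, b_, out_))  # runtime in milliseconds
    print(f"{M}x{N}x{K}: {ms:.3f} ms")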