Advertisement
Guest User

Untitled

a guest
Feb 4th, 2025
51
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.21 KB | None | 0 0
  1. import torch
  2. import triton
  3. import triton.language as tl
  4.  
  5. def benchmark(f, job_name: str):
  6. warmup = 20
  7. rep = 100
  8. ms = triton.testing.do_bench(f, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
  9. print(f"runtime of {job_name} in milliseconds: {ms[1]:.5f} ({ms[0]:.5f}, {ms[2]:.5f}).")
  10. return ms[0]
  11.  
@triton.jit
def k_cross_mm_triton(
    k_ptr, v_ptr, res_ptr, i_ptr, j_ptr,
    N, d, num_pairs,
    stride_k_row, stride_k_col,
    stride_v_row, stride_v_col,
    stride_res_row, stride_res_col,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Compute one element res[pair_idx, col_idx] of 2 * K_cross @ v.

    Launched on a 2-D grid. Axis 0 (``pair_idx``) selects a column pair
    (i, j) read from ``i_ptr`` / ``j_ptr``. Axis 1 (``col_idx``) selects a
    column of ``v``. The program walks the N rows in chunks of
    ``BLOCK_SIZE_N``, accumulating
    sum_n k[n, i] * k[n, j] * v[n, col_idx] into a float32 accumulator. It
    then stores twice that sum at res[pair_idx, col_idx].

    Args:
        k_ptr, v_ptr: Base pointers of k and v, addressed via the strides.
        res_ptr: Base pointer of the output; exactly one element is written.
        i_ptr, j_ptr: Per-pair column indices into k, indexed by pair_idx.
        N: Number of rows summed over.
        d, num_pairs: Accepted but not read by this kernel body.
        stride_*: Element strides of k, v and res.
        BLOCK_SIZE_N: Rows processed per loop iteration. Triton's tl.arange
            requires this to be a power of two.
    """
    pair_idx = tl.program_id(0)
    col_idx = tl.program_id(1)

    # Load precomputed i, j indices for this program's column pair.
    i = tl.load(i_ptr + pair_idx)
    j = tl.load(j_ptr + pair_idx)

    # Scalar float32 accumulator for the whole reduction over N.
    dot_product = tl.zeros((), dtype=tl.float32)

    for row_start in range(0, N, BLOCK_SIZE_N):
        row_range = row_start + tl.arange(0, BLOCK_SIZE_N)
        # Lanes past the last row load 0.0 below and so add nothing to the sum.
        mask = row_range < N

        # Column i / column j of k over this chunk of rows.
        # NOTE(review): if k is a contiguous row-major (N, d) tensor, these
        # are stride-d gathers, not contiguous loads. Likely the main
        # bandwidth cost; confirm with a profiler.
        # NOTE(review): row_range * stride_k_row may be computed in int32.
        # Check for overflow if N * stride_k_row can exceed 2**31.
        k_i_offsets = i * stride_k_col + row_range * stride_k_row
        k_j_offsets = j * stride_k_col + row_range * stride_k_row

        k_i = tl.load(k_ptr + k_i_offsets, mask=mask, other=0.0)
        k_j = tl.load(k_ptr + k_j_offsets, mask=mask, other=0.0)

        # Elementwise product of the two columns: this chunk of K_cross[pair_idx].
        k_cross = k_i * k_j

        # Matching chunk of column col_idx of v (same row range and mask).
        v_offsets = row_range * stride_v_row + col_idx * stride_v_col
        v_vals = tl.load(v_ptr + v_offsets, mask=mask, other=0.0)

        dot_product += tl.sum(k_cross * v_vals, axis=0)

    res_offsets = pair_idx * stride_res_row + col_idx * stride_res_col
    # Factor 2 matches the reference `2 * torch.mm(k_cross, v)`.
    tl.store(res_ptr + res_offsets, 2 * dot_product)
  49.  
  50.  
  51. def k_cross_mm(k, v):
  52. N, d = k.shape
  53. num_pairs = d * (d - 1) // 2
  54.  
  55. res = torch.empty((num_pairs, d), dtype=torch.float32, device='cuda')
  56.  
  57. # Precompute i, j indices
  58. k_triu_indices = torch.triu_indices(d, d, offset=1, device='cuda')
  59. i_indices = k_triu_indices[0].contiguous()
  60. j_indices = k_triu_indices[1].contiguous()
  61.  
  62. grid = (num_pairs, d)
  63. BLOCK_SIZE_N = 8192
  64.  
  65. k_cross_mm_triton[grid](
  66. k, v, res, i_indices, j_indices,
  67. N, d, num_pairs,
  68. k.stride(0), k.stride(1),
  69. v.stride(0), v.stride(1),
  70. res.stride(0), res.stride(1),
  71. BLOCK_SIZE_N=BLOCK_SIZE_N,
  72. )
  73. return res
  74.  
  75.  
  76. def matrix_mul_kv_original(k, v):
  77. _, d = k.shape
  78. k_t = k.transpose(-2, -1).contiguous()
  79. k_triu_indices = torch.triu_indices(d, d, offset=1)
  80. k_cross = (k_t[k_triu_indices[0]] * k_t[k_triu_indices[1]]).contiguous()
  81. res = 2 * torch.mm(k_cross, v)
  82. return res
  83.  
  84. if __name__ == "__main__":
  85. N, d = 80000, 128
  86. torch.manual_seed(42)
  87. k = torch.randn((N, d), dtype=torch.float32, device='cuda')
  88. v = torch.randn((N, d), dtype=torch.float32, device='cuda')
  89.  
  90. kv_kernel_triton = lambda: k_cross_mm(k, v)
  91. runtime_triton_kernel = benchmark(kv_kernel_triton, "Optimized Triton Cross Matrix Multiply")
  92.  
  93. kv_kernel_original = lambda: matrix_mul_kv_original(k, v)
  94. runtime_pytorch_kernel = benchmark(kv_kernel_original, "Original PyTorch Cross Matrix Multiply")
  95.  
  96. res_triton = k_cross_mm(k, v)
  97. res_pytorch = matrix_mul_kv_original(k, v)
  98.  
  99. print(res_triton.shape, res_pytorch.shape)
  100. print(torch.norm(res_pytorch - res_triton, p='fro') / torch.norm(res_pytorch, p='fro'))
  101.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement