Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import triton
- import triton.language as tl
def benchmark(f, job_name: str, *, warmup: int = 20, rep: int = 100) -> float:
    """Time ``f`` with ``triton.testing.do_bench`` and print a one-line summary.

    The printed line reports the median runtime followed by the 20th and
    80th percentiles in parentheses.

    Args:
        f: Zero-argument callable to time; ``do_bench`` invokes it repeatedly.
        job_name: Label used in the printed summary.
        warmup: Forwarded to ``do_bench`` as its ``warmup`` budget
            (documented by Triton in milliseconds). Defaults to the value
            previously hard-coded here.
        rep: Forwarded to ``do_bench`` as its ``rep`` budget (documented by
            Triton in milliseconds). Defaults to the value previously
            hard-coded here.

    Returns:
        The 20th-percentile runtime in milliseconds. This is deliberately the
        lower quantile rather than the printed median, which keeps the
        original return value for existing callers.
    """
    q_low, q_median, q_high = triton.testing.do_bench(
        f, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8]
    )
    print(f"runtime of {job_name} in milliseconds: {q_median:.5f} ({q_low:.5f}, {q_high:.5f}).")
    return q_low
# Triton kernel: each program computes one output element
#     res[pair, col] = 2 * sum_n k[n, i] * k[n, j] * v[n, col]
# where pair = program_id(0), col = program_id(1), and
# (i, j) = (i_ptr[pair], j_ptr[pair]). The N rows are streamed in chunks of
# BLOCK_SIZE_N; out-of-range rows in the last chunk are masked to 0.
# `d` and `num_pairs` are part of the launch signature but are not read here.
@triton.jit
def k_cross_mm_triton(
    k_ptr, v_ptr, res_ptr, i_ptr, j_ptr,
    N, d, num_pairs,
    stride_k_row, stride_k_col,
    stride_v_row, stride_v_col,
    stride_res_row, stride_res_col,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Widen to int64 before any multiplication by a stride: program ids and
    # tl.arange are int32, so `index * stride` would wrap once a tensor has
    # 2**31 or more elements and the loads/stores would hit wrong addresses.
    pair_idx = tl.program_id(0).to(tl.int64)
    col_idx = tl.program_id(1).to(tl.int64)
    # Load the precomputed column pair (i, j) handled by this program.
    i = tl.load(i_ptr + pair_idx).to(tl.int64)
    j = tl.load(j_ptr + pair_idx).to(tl.int64)
    dot_product = tl.zeros((), dtype=tl.float32)
    for row_start in range(0, N, BLOCK_SIZE_N):
        row_range = row_start + tl.arange(0, BLOCK_SIZE_N)
        mask = row_range < N  # the final chunk may extend past N
        rows = row_range.to(tl.int64)
        k_i = tl.load(k_ptr + i * stride_k_col + rows * stride_k_row, mask=mask, other=0.0)
        k_j = tl.load(k_ptr + j * stride_k_col + rows * stride_k_row, mask=mask, other=0.0)
        v_vals = tl.load(v_ptr + rows * stride_v_row + col_idx * stride_v_col, mask=mask, other=0.0)
        # Multiply and accumulate in fp32 even when the inputs are half precision.
        k_cross = k_i.to(tl.float32) * k_j.to(tl.float32)
        dot_product += tl.sum(k_cross * v_vals.to(tl.float32), axis=0)
    res_offsets = pair_idx * stride_res_row + col_idx * stride_res_col
    tl.store(res_ptr + res_offsets, 2 * dot_product)
def k_cross_mm(k, v):
    """Triton counterpart of ``matrix_mul_kv_original``.

    For each column pair ``i < j`` of ``k`` and each column ``c`` of ``v``,
    computes ``res[pair, c] = 2 * sum_n k[n, i] * k[n, j] * v[n, c]``.
    Pairs are enumerated in the order of
    ``torch.triu_indices(d, d, offset=1)``. One ``k_cross_mm_triton``
    program is launched per output element.

    Args:
        k: 2-D tensor of shape ``(N, d)``. Must be on a CUDA device whenever
            there is any output to compute.
        v: 2-D tensor of shape ``(N, d_v)`` on the same device as ``k``.

    Returns:
        A float32 tensor of shape ``(d * (d - 1) // 2, d_v)`` on ``k.device``.

    Raises:
        ValueError: if ``k`` or ``v`` is not 2-D, if their row counts differ,
            or if they live on different devices.
    """
    if k.dim() != 2 or v.dim() != 2:
        raise ValueError(
            f"k and v must be 2-D, got shapes {tuple(k.shape)} and {tuple(v.shape)}"
        )
    N, d = k.shape
    if v.shape[0] != N:
        raise ValueError(
            f"k and v must have the same number of rows, got {N} and {v.shape[0]}"
        )
    if v.device != k.device:
        raise ValueError(f"k and v must be on the same device, got {k.device} and {v.device}")

    num_pairs = d * (d - 1) // 2
    d_v = v.shape[1]  # output columns follow v, as in torch.mm(k_cross, v)
    # Allocate on k's device rather than the default CUDA device, so inputs on
    # e.g. cuda:1 are not paired with an output buffer on cuda:0.
    res = torch.empty((num_pairs, d_v), dtype=torch.float32, device=k.device)
    if num_pairs == 0 or d_v == 0:
        # d < 2 (no column pairs) or v has no columns: there is nothing to
        # compute, and a grid with a zero dimension has nothing to launch.
        return res

    # Precompute the (i, j) column pair handled by each program on grid axis 0.
    i_indices, j_indices = torch.triu_indices(d, d, offset=1, device=k.device)
    # Use a single chunk when N is small; otherwise keep the original 8192 cap.
    BLOCK_SIZE_N = min(8192, max(16, triton.next_power_of_2(N)))
    grid = (num_pairs, d_v)
    # Triton launches on the *current* CUDA device, so switch to k's device.
    with torch.cuda.device(k.device):
        k_cross_mm_triton[grid](
            k, v, res, i_indices, j_indices,
            N, d, num_pairs,
            k.stride(0), k.stride(1),
            v.stride(0), v.stride(1),
            res.stride(0), res.stride(1),
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )
    return res
def matrix_mul_kv_original(k, v):
    """Reference PyTorch implementation of the pairwise cross-product matmul.

    For every column pair ``i < j`` of ``k``, enumerated in
    ``torch.triu_indices(d, d, offset=1)`` order, computes
    ``2 * (k[:, i] * k[:, j]) @ v``.

    Args:
        k: 2-D tensor of shape ``(N, d)``.
        v: 2-D tensor with ``N`` rows, shape ``(N, d_v)``.

    Returns:
        Tensor of shape ``(d * (d - 1) // 2, d_v)``.
    """
    _, d = k.shape
    # Build the pair indices on k's device. Without `device`, triu_indices
    # returns CPU indices, which must be copied to the GPU on every call when
    # k is a CUDA tensor.
    rows, cols = torch.triu_indices(d, d, offset=1, device=k.device)
    # Contiguous transposed copy: each column of k becomes a contiguous row,
    # so the gathers below read contiguous memory.
    k_t = k.transpose(-2, -1).contiguous()
    # Row p of k_cross is the elementwise product of columns rows[p] and cols[p].
    k_cross = k_t[rows] * k_t[cols]
    return 2 * torch.mm(k_cross, v)
if __name__ == "__main__":
    # Script entry: compares the Triton kernel against the PyTorch reference
    # on random CUDA inputs, then times both.
    if not torch.cuda.is_available():
        raise SystemExit("This benchmark requires a CUDA-capable GPU.")

    N, d = 80000, 128
    torch.manual_seed(42)
    k = torch.randn((N, d), dtype=torch.float32, device="cuda")
    v = torch.randn((N, d), dtype=torch.float32, device="cuda")

    # Check agreement before timing, so a broken kernel is reported up front.
    res_triton = k_cross_mm(k, v)
    res_pytorch = matrix_mul_kv_original(k, v)
    print(res_triton.shape, res_pytorch.shape)
    # With ord=None, linalg.norm flattens a 2-D input, i.e. the Frobenius norm.
    rel_err = torch.linalg.norm(res_pytorch - res_triton) / torch.linalg.norm(res_pytorch)
    print(f"relative Frobenius error (Triton vs PyTorch): {rel_err.item():.3e}")

    benchmark(lambda: k_cross_mm(k, v), "Optimized Triton Cross Matrix Multiply")
    benchmark(lambda: matrix_mul_kv_original(k, v), "Original PyTorch Cross Matrix Multiply")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement