import torch
from functools import partial
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from triton.testing import do_bench


@cute.jit
def vector_mm(a: cute.Tensor, b: cute.Tensor, tile_K_idx):
    # Multiply a row tile of A against a column tile of B by broadcasting each
    # with a stride-0 mode, then sum-reduce the element-wise product.
    cute.arch.sync_threads()
    broadcast_a = cute.make_tensor(a.iterator, cute.append(a.layout, cute.make_layout(b.shape[1], stride=0)))
    cute.arch.sync_threads()
    broadcast_b = cute.make_tensor(b.iterator, cute.prepend(b.layout, cute.make_layout(a.shape[0], stride=0)))
    cute.arch.sync_threads()
    tidx, tidy, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    if max(tidx, tidy, bidx, bidy, tile_K_idx) == 0:
        cute.printf("broadcast_a: ")
        cute.printf(broadcast_a)
        cute.printf("broadcast_b: ")
        cute.printf(broadcast_b)
    cute.arch.sync_threads()
    broadcast_a = broadcast_a.load()
    cute.arch.sync_threads()
    broadcast_b = broadcast_b.load()
    cute.arch.sync_threads()
    mul_res = broadcast_a.to(cute.Float32) * broadcast_b.to(cute.Float32)
    cute.arch.sync_threads()
    mul_res_frag = cute.make_fragment(mul_res.shape, dtype=cute.Float32)
    cute.arch.sync_threads()
    mul_res_frag.store(mul_res)
    cute.arch.sync_threads()
    if max(tidx, tidy, bidx, bidy, tile_K_idx) == 0:
        cute.printf("mul_res: ")
        cute.printf(mul_res_frag)
    return mul_res.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1, None))


@cute.jit
def loop_mm(a: cute.Tensor, b: cute.Tensor, tile_K_idx):
    # Reference version: plain triple loop over (m, n, k) with scalar accumulation.
    tmp_val = cute.make_fragment((a.shape[0], b.shape[1]), dtype=cute.Float32)
    for m in cutlass.range_constexpr(a.shape[0]):
        for n in cutlass.range_constexpr(b.shape[1]):
            tmp_val[(m, n)] = 0.0
            for k in cutlass.range_constexpr(a.shape[1]):
                tmp_val[(m, n)] += a[(m, k)] * b[(k, n)]
    return tmp_val.load()


@cute.kernel
def matmul_kernel(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
    blk: cutlass.Constexpr,
):
    tidx, tidy, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    tile_M_idx = bidx
    tile_N_idx = bidy

    # Tiles of A, B, and C handled by this thread block.
    gA = cute.local_tile(mA, (blk, blk), coord=(tile_M_idx, None))
    gB = cute.local_tile(mB, (blk, blk), coord=(None, tile_N_idx))
    gC = cute.local_tile(mC, (blk, blk), coord=(tile_M_idx, tile_N_idx))

    # Passing smem layouts
    # sA_layout = cute.make_layout((blk, blk), stride=(1, blk))
    # sB_layout = cute.make_layout((blk, blk), stride=(blk, 1))
    # Failing smem layouts
    sA_layout = cute.make_layout((blk, blk), stride=(blk, 1))
    sB_layout = cute.make_layout((blk, blk), stride=(1, blk))

    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(gA.element_type, sA_layout)
    sB = smem.allocate_tensor(gB.element_type, sB_layout)

    num_k_blocks = gA.shape[2]

    tmp_val = cute.make_fragment((1, 1), dtype=cute.Float32)
    tmp_val.fill(0.0)

    for tile_K_idx in range(num_k_blocks):
        # Stage the current K-tile of A and B in shared memory, one element per thread.
        cute.arch.sync_threads()
        sA[(tidy, tidx)] = gA[(tidy, tidx, tile_K_idx)]
        sB[(tidy, tidx)] = gB[(tidy, tidx, tile_K_idx)]
        cute.arch.sync_threads()

        local_a = sA[(tidx, None)]
        local_b = sB[(None, tidy)]
        local_a = cute.local_tile(sA, (1, blk), coord=(tidx, 0))
        local_b = cute.local_tile(sB, (blk, 1), coord=(0, tidy))

        cur = vector_mm(local_a, local_b, tile_K_idx)
        # cur = loop_mm(local_a, local_b, tile_K_idx)  # Passes with either smem layout
        tmp_val.store(tmp_val.load() + cur)

    # Each thread writes one element of the output tile.
    gC[(tidx, tidy)] = tmp_val[(0, 0)].to(cute.Float16)


@cute.jit
def matmul(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    m, k1 = mA.shape
    k2, n = mB.shape
    blk = 16
    smem_size = (
        cute.size_in_bytes(mA.element_type, cute.make_layout((blk, blk), stride=(blk, 1)))
        + cute.size_in_bytes(mA.element_type, cute.make_layout((blk, blk), stride=(1, blk)))
    )
    kernel = matmul_kernel(mA, mB, mC, blk)
    kernel.launch(
        grid=(cute.ceil_div(mA.shape[0], blk), cute.ceil_div(mB.shape[1], blk), 1),
        block=(blk, blk, 1),
        smem=smem_size,
    )


def test():
    torch.manual_seed(0)
    N = 16
    M = 16
    K = 16
    a = torch.ones(M, K, device='cuda', dtype=torch.float16)
    b = torch.ones(K, N, device='cuda', dtype=torch.float16)
    # Column-major output buffer of shape (M, N).
    out = torch.zeros(N, M, device='cuda', dtype=torch.float16).T
    a_ = from_dlpack(a, assumed_align=16)
    b_ = from_dlpack(b, assumed_align=16)
    out_ = from_dlpack(out, assumed_align=16)
    matmul(a_, b_, out_)
    ref = torch.mm(a, b)
    torch.testing.assert_close(out, ref)
    print("Test passed")


test()
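

# The do_bench import above is otherwise unused. Below is a minimal, untested
# sketch of how it could be used to time the CuTe kernel against torch.mm on
# the same 16x16 setup; the bench() helper is not part of the original repro
# and is not called here.
def bench():
    a = torch.randn(16, 16, device='cuda', dtype=torch.float16)
    b = torch.randn(16, 16, device='cuda', dtype=torch.float16)
    out = torch.zeros(16, 16, device='cuda', dtype=torch.float16).T
    a_ = from_dlpack(a, assumed_align=16)
    b_ = from_dlpack(b, assumed_align=16)
    out_ = from_dlpack(out, assumed_align=16)
    # do_bench returns a time in milliseconds for the given callable.
    cute_ms = do_bench(lambda: matmul(a_, b_, out_))
    torch_ms = do_bench(lambda: torch.mm(a, b))
    print(f"cute kernel: {cute_ms:.4f} ms, torch.mm: {torch_ms:.4f} ms")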