Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- from functools import partial
- import math
- import operator
- from typing import Callable, Optional
- import cutlass
- import cutlass.cute as cute
- from cutlass import const_expr, Float16, Float32, Int32, Int64, Boolean
- from cutlass.cute.nvgpu import cpasync, warp, warpgroup
- from cutlass.cutlass_dsl import T, dsl_user_op
- from cutlass._mlir.dialects import llvm, nvvm, vector
- from cutlass.utils import LayoutEnum, StaticPersistentTileScheduler
- import cutlass.utils.hopper_helpers as sm90_utils
- from cutlass.cute.runtime import from_dlpack
- import cuda.bindings.driver as cuda
- from cutlass.cute.runtime import make_ptr
@cute.jit
def semaphore_acquire(
    lock: cute.Pointer,
    state: Int32,
):
    """Poll the semaphore word once with a GPU-scope acquire load.

    Thread 0 of the block executes ``ld.global.acquire.gpu.b32`` on ``lock``
    and returns the value it loaded; every other thread returns ``state``
    unchanged. A block-wide barrier follows the load, so all threads of the
    block must call this function together.

    Args:
        lock: Pointer to the 32-bit semaphore word (read with ``ld.global``).
        state: Value handed back by the threads that do not perform the load.

    Returns:
        The freshly loaded lock value in thread 0, ``state`` in all others.
    """
    lock_addr = lock.toint().ir_value()
    tx = cute.arch.thread_idx()[0]
    # Debug trace of the incoming state for the first two threads.
    if tx == 0 or tx == 1:
        cute.printf("before tx={} state = {}", tx, state)
    observed = state
    if tx == 0:
        # The loaded word must be an asm *output* ("=r" with an i32 result).
        # Passing `state` as an input and letting the load overwrite it never
        # propagates the value back into the returned result.
        observed = Int32(
            llvm.inline_asm(
                T.i32(),  # return: the loaded 32-bit word
                [lock_addr],  # input: address of the lock
                "ld.global.acquire.gpu.b32 $0, [$1];",
                "=r,l",
                has_side_effects=True,
                is_align_stack=False,
                asm_dialect=llvm.AsmDialect.AD_ATT,
            )
        )
    cute.arch.sync_threads()
    # Debug trace of the state each of the first two threads will return.
    if tx == 0 or tx == 1:
        cute.printf("tx={} state = {}", tx, observed)
    return observed
@cute.jit
def semaphore_release(
    lock: cute.Pointer,
    state: Int32,
):
    """Publish ``state`` to the semaphore word with a GPU-scope release store.

    All threads of the block must call this function together: a block-wide
    barrier runs first, then thread 0 alone executes
    ``st.global.release.gpu.b32``.

    Args:
        lock: Pointer to the 32-bit semaphore word (written with ``st.global``).
        state: Value stored into the semaphore word.
    """
    lock_addr = lock.toint().ir_value()
    tx = cute.arch.thread_idx()[0]
    # Wait for the whole block before thread 0 publishes the new state, so no
    # thread is still working on the guarded section when the word changes
    # (same ordering as CUTLASS's Semaphore::release).
    cute.arch.sync_threads()
    if tx == 0:
        llvm.inline_asm(
            None,  # no result: this is a store
            [lock_addr, state.ir_value()],  # inputs: address, value
            "st.global.release.gpu.b32 [$0], $1;",
            "l,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
@cute.jit
def semaphore_wait(
    lock: cute.Pointer,
    need_state: Int32,
):
    """Poll ``lock`` via ``semaphore_acquire`` until the vote stops the loop.

    The semaphore is sampled once up front and then re-sampled for as long as
    ``cute.arch.vote_all_sync`` reports that the sampled value differs from
    ``need_state``. A block-wide barrier is issued once the loop ends.

    Args:
        lock: Pointer to the semaphore word handed to ``semaphore_acquire``.
        need_state: Value the caller wants to observe; it is also passed to
            ``semaphore_acquire`` as the incoming state.
    """
    # NOTE(review): semaphore_acquire only loads in thread 0 and hands
    # `need_state` straight back to every other thread, so those threads never
    # see a mismatch. If vote_all_sync requires the predicate on all lanes,
    # the loop below cannot spin - confirm the intended polling scheme.
    observed = semaphore_acquire(lock, need_state)
    while cute.arch.vote_all_sync(observed != need_state):
        # Still not at the requested value: sample the lock again.
        observed = semaphore_acquire(lock, need_state)
    # Re-converge the block before any thread proceeds past the wait.
    cute.arch.sync_threads()
@cute.kernel
def demo_kernel(
    a_ptr: cute.Pointer
):
    """Debug kernel that exercises the semaphore helpers.

    Each block fills a 32-entry shared-memory tensor with ``tx + bx``, waits
    on the semaphore at ``a_ptr`` for the value ``1 + bx``, lets thread 0
    print the block index and the tensor, and finally releases the semaphore
    with that same value.

    Args:
        a_ptr: Pointer to the Int32 semaphore word shared by all blocks.
    """
    tid = cute.arch.thread_idx()[0]
    bid = cute.arch.block_idx()[0]
    # Per-block semaphore value: 1 for block 0, 2 for block 1, ...
    # NOTE(review): the same value is used for both the wait and the release,
    # and the host zero-initialises the word - confirm the intended handshake.
    ticket = Int32(1) + bid

    # One Int32 slot per thread; the fixed 32-entry layout is indexed by the
    # thread index, so it relies on the launch using at most 32 threads.
    allocator = cutlass.utils.SmemAllocator()
    staged = allocator.allocate_tensor(Int32, cute.make_layout((32,)))
    staged[tid] = tid + bid
    cute.arch.sync_threads()

    semaphore_wait(a_ptr, ticket)
    if tid == 0:
        cute.printf("bx={}", bid)
        cute.print_tensor(staged, verbose=True)
    semaphore_release(a_ptr, ticket)
@cute.jit
def demof(
    a_ptr: cute.Pointer,
):
    """JIT entry point: launch ``demo_kernel`` on 2 blocks of 32 threads.

    Args:
        a_ptr: Pointer to the Int32 semaphore word forwarded to the kernel.
    """
    grid_dims = [2, 1, 1]
    # 32 threads per block matches the 32-entry tensor demo_kernel allocates.
    block_dims = [32, 1, 1]
    kernel = demo_kernel(a_ptr)
    kernel.launch(grid=grid_dims, block=block_dims)
# NOTE(review): the result is discarded; presumably this allocation only
# forces CUDA context creation before the CuTe launch - confirm.
torch.ones(3, 4, device="cuda")
# Single zero-initialised int32 word on the device, used as the semaphore.
a = torch.zeros(1, device="cuda", dtype=torch.int32)
a_ptr = make_ptr(
    Int32, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
demof(a_ptr)
# The launch is asynchronous: block until the kernel has finished so its
# device-side printf output is flushed and any launch/runtime error surfaces
# before the script exits.
torch.cuda.synchronize()
Advertisement
Add Comment
Please, Sign In to add comment