Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import triton
- import triton.language as tl
- from torch._inductor.ir import ReductionHint
- from torch._inductor.ir import TileHint
- from torch._inductor.triton_heuristics import AutotuneHint, reduction
- from torch._inductor.utils import instance_descriptor
- from torch._inductor import triton_helpers
- from triton.compiler.compiler import AttrsDescriptor
@reduction(
    size_hints=[262144, 4096],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*i64', 2: '*i8', 3: '*bf16', 4: '*i8', 5: '*bf16', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(8, 9))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_9', 'mutated_arg_names': [], 'no_x_dim': False}
)
@triton.jit
def triton_red_fused_bmm_9(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    """Indirectly-indexed int8 dot-product reduction (TorchInductor "fused_bmm").

    For every flat output index x, let
        x1   = x // 28672
        slot = (x % 28672) // 14336        # 0 or 1
        col  = x % 14336
        e    = in_ptr1[2*x1 + slot]        # wrapped by +8 when negative
    and compute, in fp32:
        out_ptr0[x] = sum_{r < 4096} in_ptr0[4096*x1 + r]
                      * (in_ptr2[58720256*e + 4096*col + r] * in_ptr3[14336*e + col])
        out_ptr1[x] = the same with in_ptr4 / in_ptr5 in place of in_ptr2 / in_ptr3.
    (58720256 == 14336 * 4096, i.e. one [14336, 4096] slice of in_ptr2/in_ptr4 per e.)

    Pointer dtypes (from triton_meta): in_ptr0 bf16, in_ptr1 int64,
    in_ptr2/in_ptr4 int8, in_ptr3/in_ptr5 bf16, out_ptr0/out_ptr1 fp32.

    NOTE(review): this looks like a top-2 mixture-of-experts step over 8
    experts with int8 weights and per-row bf16 scales (hidden 4096, inner
    14336) -- presumably; confirm against the originating FX graph.
    NOTE(review): every load/store passes mask=None, so the kernel relies on
    xnumel being a multiple of XBLOCK and 4096 a multiple of RBLOCK; the
    runtime `rnumel` argument is ignored.
    """
    rnumel = 4096  # reduction length is baked in by codegen
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]  # [XBLOCK, 1]
    rbase = tl.arange(0, RBLOCK)[None, :]  # [1, RBLOCK]

    # Split the flat output index into (x1, slot, col) as described above.
    x1 = xindex // 28672
    x0 = xindex % 28672
    slot = x0 // 14336
    col = x0 % 14336

    # Indirect selector, read once per output element.
    sel = tl.load(in_ptr1 + ((2 * x1) + slot), None, eviction_policy='evict_last')
    # Python-style negative-index wraparound over a dimension of size 8.
    sel = tl.where(sel < 0, sel + 8, sel)
    # Guard the indirect index. This follows Inductor's own indirect-indexing
    # asserts; Triton typically compiles it out unless debug mode is enabled.
    tl.device_assert((0 <= sel) & (sel < 8), "index out of bounds: 0 <= sel < 8")

    # Why hoisted: the scales depend only on (sel, col), not on the reduction
    # index. The original re-loaded them for every RBLOCK chunk.
    scale_a = tl.load(in_ptr3 + ((14336 * sel) + col), None, eviction_policy='evict_first').to(tl.float32)
    scale_b = tl.load(in_ptr5 + ((14336 * sel) + col), None, eviction_policy='evict_first').to(tl.float32)
    # Loop-invariant base offsets of the activation row and the weight row.
    act_base = 4096 * x1
    w_base = (4096 * col) + (58720256 * sel)

    acc_a = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    acc_b = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        r2 = roffset + rbase
        act = tl.load(in_ptr0 + (r2 + act_base), None, eviction_policy='evict_last').to(tl.float32)
        w_a = tl.load(in_ptr2 + (r2 + w_base), None, eviction_policy='evict_first').to(tl.float32)
        w_b = tl.load(in_ptr4 + (r2 + w_base), None, eviction_policy='evict_first').to(tl.float32)
        # Same operation order as the generated code: act * (w * scale), so
        # results stay bit-identical to the unhoisted version.
        acc_a = acc_a + tl.broadcast_to(act * (w_a * scale_a), [XBLOCK, RBLOCK])
        acc_b = acc_b + tl.broadcast_to(act * (w_b * scale_b), [XBLOCK, RBLOCK])

    tl.store(out_ptr0 + xindex, tl.sum(acc_a, 1)[:, None], None)
    tl.store(out_ptr1 + xindex, tl.sum(acc_b, 1)[:, None], None)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement