Advertisement
Guest User

Untitled

a guest
Feb 26th, 2024
179
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.01 KB | None | 0 0
  1.  
  2. import triton
  3. import triton.language as tl
  4. from torch._inductor.ir import ReductionHint
  5. from torch._inductor.ir import TileHint
  6. from torch._inductor.triton_heuristics import AutotuneHint, reduction
  7. from torch._inductor.utils import instance_descriptor
  8. from torch._inductor import triton_helpers
  9. from triton.compiler.compiler import AttrsDescriptor
  10.  
  11. @reduction(
  12. size_hints=[262144, 4096],
  13. reduction_hint=ReductionHint.DEFAULT,
  14. filename=__file__,
  15. triton_meta={'signature': {0: '*bf16', 1: '*i64', 2: '*i8', 3: '*bf16', 4: '*i8', 5: '*bf16', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(8, 9))]},
  16. inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_bmm_9', 'mutated_arg_names': [], 'no_x_dim': False}
  17. )
  18. @triton.jit
  19. def triton_red_fused_bmm_9(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
  20. rnumel = 4096
  21. xoffset = tl.program_id(0) * XBLOCK
  22. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  23. xmask = xindex < xnumel
  24. rbase = tl.arange(0, RBLOCK)[None, :]
  25. x1 = (xindex // 28672)
  26. x0 = xindex % 28672
  27. tmp2 = tl.load(in_ptr1 + ((2*x1) + (x0 // 14336)), None, eviction_policy='evict_last')
  28. _tmp13 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
  29. x3 = xindex
  30. _tmp22 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
  31. for roffset in range(0, rnumel, RBLOCK):
  32. rindex = roffset + rbase
  33. rmask = rindex < rnumel
  34. r2 = rindex
  35. tmp0 = tl.load(in_ptr0 + (r2 + (4096*x1)), None, eviction_policy='evict_last').to(tl.float32)
  36. tmp1 = tmp0.to(tl.float32)
  37. tmp3 = tmp2 + 8
  38. tmp4 = tmp2 < 0
  39. tmp5 = tl.where(tmp4, tmp3, tmp2)
  40. tmp6 = tl.load(in_ptr2 + (r2 + (4096*(x0 % 14336)) + (58720256*tmp5)), None, eviction_policy='evict_first')
  41. tmp7 = tmp6.to(tl.float32)
  42. tmp8 = tl.load(in_ptr3 + ((14336*tmp5) + (x0 % 14336)), None, eviction_policy='evict_first').to(tl.float32)
  43. tmp9 = tmp7 * tmp8
  44. tmp10 = tmp9.to(tl.float32)
  45. tmp11 = tmp1 * tmp10
  46. tmp12 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
  47. tmp14 = _tmp13 + tmp12
  48. _tmp13 = tmp14
  49. tmp15 = tl.load(in_ptr4 + (r2 + (4096*(x0 % 14336)) + (58720256*tmp5)), None, eviction_policy='evict_first')
  50. tmp16 = tmp15.to(tl.float32)
  51. tmp17 = tl.load(in_ptr5 + ((14336*tmp5) + (x0 % 14336)), None, eviction_policy='evict_first').to(tl.float32)
  52. tmp18 = tmp16 * tmp17
  53. tmp19 = tmp18.to(tl.float32)
  54. tmp20 = tmp1 * tmp19
  55. tmp21 = tl.broadcast_to(tmp20, [XBLOCK, RBLOCK])
  56. tmp23 = _tmp22 + tmp21
  57. _tmp22 = tmp23
  58. tmp13 = tl.sum(_tmp13, 1)[:, None]
  59. tl.store(out_ptr0 + (x3), tmp13, None)
  60. tmp22 = tl.sum(_tmp22, 1)[:, None]
  61. tl.store(out_ptr1 + (x3), tmp22, None)
  62.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement