Guest User

Untitled

a guest
May 8th, 2024
201
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 111.92 KB | None | 0 0
  1.  
  2. # AOT ID: ['0_inference']
  3. from ctypes import c_void_p, c_long
  4. import torch
  5. import math
  6. import random
  7. import os
  8. import tempfile
  9. from math import inf, nan
  10. from torch._inductor.hooks import run_intermediate_hooks
  11. from torch._inductor.utils import maybe_profile
  12. from torch._inductor.codegen.memory_planning import _align as align
  13.  
  14. from torch import device, empty_strided
  15. from torch._inductor.codecache import AsyncCompile
  16. from torch._inductor.select_algorithm import extern_kernels
  17. from torch._inductor.codegen.multi_kernel import MultiKernelCall
  18.  
# Short module-level aliases for operator namespaces and private torch
# runtime helpers.
# NOTE(review): presumably consumed by the host-side wrapper code further
# down the file (not visible in this chunk); several of these attributes are
# private and differ between torch versions.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
  27.  
  28.  
  29.  
  30. # kernel path: /tmp/torchinductor_chilli/ey/ceyko6dtcfce7u2l2mwcbbrcenhuy63ap6ko6ppldgjdojixko7k.py
  31. # Source Nodes: [x], Original ATen: [aten.convolution]
  32. # x => convolution
  33. import triton
  34. import triton.language as tl
  35. from triton.compiler.compiler import AttrsDescriptor
  36.  
  37. from torch._inductor.runtime import triton_helpers, triton_heuristics
  38. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  39. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  40.  
  41. @triton.jit
  42. def triton_poi_fused_convolution_0(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  43. ynumel = 3
  44. xnumel = 51984
  45. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  46. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  47. ymask = yindex < ynumel
  48. xoffset = tl.program_id(0) * XBLOCK
  49. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  50. xmask = xindex < xnumel
  51. x1 = xindex
  52. y0 = yindex
  53. tmp0 = tl.load(in_ptr0 + (x1 + (51984*y0)), xmask & ymask, eviction_policy='evict_last')
  54. tl.store(out_ptr0 + (y0 + (3*x1)), tmp0, xmask & ymask)
  55.  
  56.  
  57.  
  58. import triton
  59. import triton.language as tl
  60. from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
  61. from torch._C import _cuda_getCurrentRawStream as get_raw_stream
  62.  
  63.  
  64. # kernel path: /tmp/torchinductor_chilli/fg/cfgog4hjgpnrqznjhdl7s57zvld6hr66mq3j44s7xmbtaewazh33.py
  65. # Source Nodes: [x], Original ATen: [aten.convolution]
  66. # x => convolution
  67. import triton
  68. import triton.language as tl
  69. from triton.compiler.compiler import AttrsDescriptor
  70.  
  71. from torch._inductor.runtime import triton_helpers, triton_heuristics
  72. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  73. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  74.  
  75. @triton.jit
  76. def triton_poi_fused_convolution_1(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  77. ynumel = 192
  78. xnumel = 49
  79. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  80. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  81. ymask = yindex < ynumel
  82. xoffset = tl.program_id(0) * XBLOCK
  83. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  84. xmask = xindex < xnumel
  85. x2 = xindex
  86. y3 = yindex
  87. y0 = yindex % 3
  88. y1 = (yindex // 3)
  89. tmp0 = tl.load(in_ptr0 + (x2 + (49*y3)), xmask & ymask, eviction_policy='evict_last')
  90. tl.store(out_ptr0 + (y0 + (3*x2) + (147*y1)), tmp0, xmask & ymask)
  91.  
  92.  
  93.  
  94.  
  95. # kernel path: /tmp/torchinductor_chilli/oy/coyvppj4wtlyylbacoaiwv7r5c2bcteoyqnizkfxwnizxz3hmqxh.py
  96. # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
  97. # x_1 => var_mean
  98. import triton
  99. import triton.language as tl
  100. from triton.compiler.compiler import AttrsDescriptor
  101.  
  102. from torch._inductor.runtime import triton_helpers, triton_heuristics
  103. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  104. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  105.  
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_2(in_ptr0, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    # First stage of a split Welford reduction for x_1's var_mean.
    # Each of the 6528 outputs x4 = x0 + 64*x1 + 3264*x2 covers one channel
    # x0 in [0, 64) and one chunk of up to 128 spatial positions:
    #   position p = r3 + 128*x1 + 6498*x2, x1 in [0, 51), x2 in [0, 2).
    # The input is read with strides 64 (per column) and 7296 == 64*114 (per
    # row), i.e. a 114x114 plane with the 64 channels innermost.
    # Per-chunk results of triton_helpers.welford are written to
    # out_ptr0/out_ptr1/out_ptr2. NOTE(review): presumably (mean, M2, weight),
    # matching how the downstream merge kernels consume them — confirm.
    xnumel = 6528
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x1 = (xindex // 64) % 51
    x0 = xindex % 64
    x2 = (xindex // 3264)
    x4 = xindex
    # 51 chunks * 128 = 6528 > 6498 positions per half, so the tail of the
    # last chunk is padding; tmp2 marks the real positions.
    tmp0 = r3 + (128*x1)
    tmp1 = tl.full([1, 1], 6498, tl.int32)
    tmp2 = tmp0 < tmp1
    # Column = p % 114 (6498*x2 is a multiple of 114, so it is omitted there),
    # row = (p // 114) % 114.
    tmp3 = tl.load(in_ptr0 + (x0 + (64*((r3 + (128*x1)) % 114)) + (7296*(((r3 + (128*x1) + (6498*x2)) // 114) % 114))), rmask & tmp2 & xmask, other=0.0)
    tmp4 = tl.full(tmp3.shape, 0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    # Initial Welford state per element: M2 = 0, weight = 1 for real
    # positions and weight = 0 for padding.
    tmp6 = 0.0
    tmp7 = tl.full(tmp6.shape, 0, tmp6.dtype)
    tmp8 = tl.where(tmp2, tmp6, tmp7)
    tmp9 = 1.0
    tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
    tmp11 = tl.where(tmp2, tmp9, tmp10)
    tmp12 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
    tmp13 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
    tmp14 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
    tmp16 = tl.where(rmask & xmask, tmp12, 0)
    tmp17 = tl.where(rmask & xmask, tmp13, 0)
    tmp18 = tl.where(rmask & xmask, tmp14, 0)
    # Combine along the reduction axis (axis 1).
    tmp19, tmp20, tmp21 = triton_helpers.welford(tmp16, tmp17, tmp18, 1)
    tmp22 = tmp19[:, None]
    tmp23 = tmp20[:, None]
    tmp24 = tmp21[:, None]
    tl.store(out_ptr0 + (x4), tmp22, xmask)
    tl.store(out_ptr1 + (x4), tmp23, xmask)
    tl.store(out_ptr2 + (x4), tmp24, xmask)
  147.  
  148.  
  149.  
  150.  
  151. # kernel path: /tmp/torchinductor_chilli/hc/chcl7tgr3scofksd7kbezuz3nvq5wqe3nivxob6lllb27qsootcq.py
  152. # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
  153. # x_1 => var_mean
  154. import triton
  155. import triton.language as tl
  156. from triton.compiler.compiler import AttrsDescriptor
  157.  
  158. from torch._inductor.runtime import triton_helpers, triton_heuristics
  159. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  160. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  161.  
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_3(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Second stage of x_1's split Welford reduction: merges 51 partial
    # triples into one triple per output x3 = x0 + 64*x1 (channel x0 in
    # [0, 64), x1 in [0, 2)). Partials are read at x0 + 64*r2 + 3264*x1
    # (3264 == 64*51).
    # NOTE(review): in_ptr0/1/2 are presumably the (mean, M2, weight)
    # partials written by the first stage — confirm against the wrapper.
    xnumel = 128
    rnumel = 51
    # RBLOCK is rounded up to 64; lanes r2 >= 51 are masked off below.
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r2 = rindex
    x0 = xindex % 64
    x1 = (xindex // 64)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*r2) + (3264*x1)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (x0 + (64*r2) + (3264*x1)), rmask & xmask, other=0.0)
    tmp2 = tl.load(in_ptr2 + (x0 + (64*r2) + (3264*x1)), rmask & xmask, other=0.0)
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp5 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    # Masked lanes get weight 0, so presumably they contribute nothing to
    # the merge.
    tmp7 = tl.where(rmask & xmask, tmp3, 0)
    tmp8 = tl.where(rmask & xmask, tmp4, 0)
    tmp9 = tl.where(rmask & xmask, tmp5, 0)
    tmp10, tmp11, tmp12 = triton_helpers.welford(tmp7, tmp8, tmp9, 1)
    tmp13 = tmp10[:, None]
    tmp14 = tmp11[:, None]
    tmp15 = tmp12[:, None]
    tl.store(out_ptr0 + (x3), tmp13, xmask)
    tl.store(out_ptr1 + (x3), tmp14, xmask)
    tl.store(out_ptr2 + (x3), tmp15, xmask)
  193.  
  194.  
  195.  
  196.  
  197. # kernel path: /tmp/torchinductor_chilli/7u/c7uiyrzseejytdgzkmlnbflxex6dricmsk7lfiz42tf55su5rnn5.py
  198. # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
  199. # x_1 => add_2, add_3, mul_1, mul_2, mul_3, mul_4, mul_5, var_mean
  200. import triton
  201. import triton.language as tl
  202. from triton.compiler.compiler import AttrsDescriptor
  203.  
  204. from torch._inductor.runtime import triton_helpers, triton_heuristics
  205. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  206. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  207.  
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_4(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Final stage of x_1's var_mean for 64 channels: merges the 2 partial
    # Welford triples per channel (read at x0 + 64*r1), then writes
    #   out_ptr0 = merged mean
    #   out_ptr1 = merged M2 (sum of squared deviations)
    #   out_ptr3 = 0.1 * mean + 0.9 * in_ptr3
    #   out_ptr5 = 0.1 * (M2 / 12996) * (12996 / 12995) + 0.9 * in_ptr4
    # i.e. a momentum-0.1 update of what look like running mean/var buffers,
    # using the unbiased variance (12996 elements per channel).
    xnumel = 64
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*r1)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (x0 + (64*r1)), rmask & xmask, other=0.0)
    tmp2 = tl.load(in_ptr2 + (x0 + (64*r1)), rmask & xmask, other=0.0)
    tmp18 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    tmp27 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp5 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp3, 0)
    tmp8 = tl.where(rmask & xmask, tmp4, 0)
    tmp9 = tl.where(rmask & xmask, tmp5, 0)
    tmp10, tmp11, tmp12 = triton_helpers.welford(tmp7, tmp8, tmp9, 1)
    tmp13 = tmp10[:, None]
    tmp14 = tmp11[:, None]
    # The merged weight (tmp15) is not used further.
    tmp15 = tmp12[:, None]
    tmp16 = 0.1
    tmp17 = tmp13 * tmp16
    tmp19 = 0.9
    tmp20 = tmp18 * tmp19
    tmp21 = tmp17 + tmp20
    # Biased variance M2 / N, then Bessel correction N / (N - 1) = 12996 / 12995.
    tmp22 = 12996.0
    tmp23 = tmp14 / tmp22
    tmp24 = 1.0000769526741053
    tmp25 = tmp23 * tmp24
    tmp26 = tmp25 * tmp16
    tmp28 = tmp27 * tmp19
    tmp29 = tmp26 + tmp28
    tl.store(out_ptr3 + (x0), tmp21, xmask)
    tl.store(out_ptr5 + (x0), tmp29, xmask)
    tl.store(out_ptr0 + (x0), tmp13, xmask)
    tl.store(out_ptr1 + (x0), tmp14, xmask)
  252.  
  253.  
  254.  
  255.  
  256. # kernel path: /tmp/torchinductor_chilli/qw/cqwedz2pcqdgqlxggudqdyaimhwkci7xpj64jg25hdihc6bvvrat.py
  257. # Source Nodes: [x_1, x_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  258. # x_1 => add_1, add_4, mul, mul_6, rsqrt, sub, var_mean
  259. # x_2 => relu
  260. import triton
  261. import triton.language as tl
  262. from triton.compiler.compiler import AttrsDescriptor
  263.  
  264. from torch._inductor.runtime import triton_helpers, triton_heuristics
  265. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  266. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  267.  
  268. @triton.jit
  269. def triton_poi_fused__native_batch_norm_legit_functional_relu_5(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, XBLOCK : tl.constexpr):
  270. xnumel = 831744
  271. xoffset = tl.program_id(0) * XBLOCK
  272. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  273. xmask = xindex < xnumel
  274. x2 = xindex
  275. x0 = xindex % 64
  276. tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
  277. tmp1 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
  278. tmp3 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  279. tmp10 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  280. tmp12 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  281. tmp2 = tmp0 - tmp1
  282. tmp4 = 12996.0
  283. tmp5 = tmp3 / tmp4
  284. tmp6 = 1e-05
  285. tmp7 = tmp5 + tmp6
  286. tmp8 = libdevice.rsqrt(tmp7)
  287. tmp9 = tmp2 * tmp8
  288. tmp11 = tmp9 * tmp10
  289. tmp13 = tmp11 + tmp12
  290. tmp14 = triton_helpers.maximum(0, tmp13)
  291. tl.store(in_out_ptr0 + (x2), tmp14, xmask)
  292.  
  293.  
  294.  
  295.  
  296. # kernel path: /tmp/torchinductor_chilli/kc/ckc6pbzy63vjkiqb3hj3kfrrcfnpzyr4o6bd26275bq73l3asf3q.py
  297. # Source Nodes: [x_1, x_2, x_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.max_pool2d_with_indices, aten.relu]
  298. # x_1 => add_1, add_4, mul, mul_6, rsqrt, sub, var_mean
  299. # x_2 => relu
  300. # x_3 => max_pool2d_with_indices
  301. import triton
  302. import triton.language as tl
  303. from triton.compiler.compiler import AttrsDescriptor
  304.  
  305. from torch._inductor.runtime import triton_helpers, triton_heuristics
  306. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  307. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  308.  
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_max_pool2d_with_indices_relu_6(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # 3x3 max pool with stride 2 and padding 1 over a 114x114 plane with 64
    # channels innermost (input strides: 64 per column, 7296 == 64*114 per
    # row), producing 57x57x64 = 207936 outputs.
    # Output (row x2, col x1, channel x0) takes the max over input rows
    # 2*x2 + {-1, 0, 1} and cols 2*x1 + {-1, 0, 1}. Taps outside [0, 114)
    # contribute -inf.
    # Each tap's constant offset is 7296*dr + 64*dc relative to
    # x0 + 128*x1 + 14592*x2, e.g. -7360 = -7296 - 64 for (dr, dc) = (-1, -1).
    # NOTE(review): only in_ptr0 is read; the batch-norm/relu ops in the name
    # are presumably already applied to in_ptr0 by an earlier kernel — confirm.
    xnumel = 207936
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 3648)
    x1 = (xindex // 64) % 57
    x0 = xindex % 64
    x4 = xindex
    # Row 2*x2 - 1 in bounds?
    tmp0 = (-1) + (2*x2)
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 114, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    # Col 2*x1 - 1 in bounds?
    tmp6 = (-1) + (2*x1)
    tmp7 = tmp6 >= tmp1
    tmp8 = tmp6 < tmp3
    tmp9 = tmp7 & tmp8
    # Tap (-1, -1).
    tmp10 = tmp5 & tmp9
    tmp11 = tl.load(in_ptr0 + ((-7360) + x0 + (128*x1) + (14592*x2)), tmp10 & xmask, other=0.0)
    tmp12 = tl.full(tmp11.shape, float("-inf"), tmp11.dtype)
    tmp13 = tl.where(tmp10, tmp11, tmp12)
    # Col 2*x1 in bounds?
    tmp14 = 2*x1
    tmp15 = tmp14 >= tmp1
    tmp16 = tmp14 < tmp3
    tmp17 = tmp15 & tmp16
    # Tap (-1, 0).
    tmp18 = tmp5 & tmp17
    tmp19 = tl.load(in_ptr0 + ((-7296) + x0 + (128*x1) + (14592*x2)), tmp18 & xmask, other=0.0)
    tmp20 = tl.full(tmp19.shape, float("-inf"), tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = triton_helpers.maximum(tmp21, tmp13)
    # Col 2*x1 + 1 in bounds?
    tmp23 = 1 + (2*x1)
    tmp24 = tmp23 >= tmp1
    tmp25 = tmp23 < tmp3
    tmp26 = tmp24 & tmp25
    # Tap (-1, +1).
    tmp27 = tmp5 & tmp26
    tmp28 = tl.load(in_ptr0 + ((-7232) + x0 + (128*x1) + (14592*x2)), tmp27 & xmask, other=0.0)
    tmp29 = tl.full(tmp28.shape, float("-inf"), tmp28.dtype)
    tmp30 = tl.where(tmp27, tmp28, tmp29)
    tmp31 = triton_helpers.maximum(tmp30, tmp22)
    # Row 2*x2 in bounds?
    tmp32 = 2*x2
    tmp33 = tmp32 >= tmp1
    tmp34 = tmp32 < tmp3
    tmp35 = tmp33 & tmp34
    # Tap (0, -1).
    tmp36 = tmp35 & tmp9
    tmp37 = tl.load(in_ptr0 + ((-64) + x0 + (128*x1) + (14592*x2)), tmp36 & xmask, other=0.0)
    tmp38 = tl.full(tmp37.shape, float("-inf"), tmp37.dtype)
    tmp39 = tl.where(tmp36, tmp37, tmp38)
    tmp40 = triton_helpers.maximum(tmp39, tmp31)
    # Tap (0, 0).
    tmp41 = tmp35 & tmp17
    tmp42 = tl.load(in_ptr0 + (x0 + (128*x1) + (14592*x2)), tmp41 & xmask, other=0.0)
    tmp43 = tl.full(tmp42.shape, float("-inf"), tmp42.dtype)
    tmp44 = tl.where(tmp41, tmp42, tmp43)
    tmp45 = triton_helpers.maximum(tmp44, tmp40)
    # Tap (0, +1).
    tmp46 = tmp35 & tmp26
    tmp47 = tl.load(in_ptr0 + (64 + x0 + (128*x1) + (14592*x2)), tmp46 & xmask, other=0.0)
    tmp48 = tl.full(tmp47.shape, float("-inf"), tmp47.dtype)
    tmp49 = tl.where(tmp46, tmp47, tmp48)
    tmp50 = triton_helpers.maximum(tmp49, tmp45)
    # Row 2*x2 + 1 in bounds?
    tmp51 = 1 + (2*x2)
    tmp52 = tmp51 >= tmp1
    tmp53 = tmp51 < tmp3
    tmp54 = tmp52 & tmp53
    # Tap (+1, -1).
    tmp55 = tmp54 & tmp9
    tmp56 = tl.load(in_ptr0 + (7232 + x0 + (128*x1) + (14592*x2)), tmp55 & xmask, other=0.0)
    tmp57 = tl.full(tmp56.shape, float("-inf"), tmp56.dtype)
    tmp58 = tl.where(tmp55, tmp56, tmp57)
    tmp59 = triton_helpers.maximum(tmp58, tmp50)
    # Tap (+1, 0).
    tmp60 = tmp54 & tmp17
    tmp61 = tl.load(in_ptr0 + (7296 + x0 + (128*x1) + (14592*x2)), tmp60 & xmask, other=0.0)
    tmp62 = tl.full(tmp61.shape, float("-inf"), tmp61.dtype)
    tmp63 = tl.where(tmp60, tmp61, tmp62)
    tmp64 = triton_helpers.maximum(tmp63, tmp59)
    # Tap (+1, +1).
    tmp65 = tmp54 & tmp26
    tmp66 = tl.load(in_ptr0 + (7360 + x0 + (128*x1) + (14592*x2)), tmp65 & xmask, other=0.0)
    tmp67 = tl.full(tmp66.shape, float("-inf"), tmp66.dtype)
    tmp68 = tl.where(tmp65, tmp66, tmp67)
    tmp69 = triton_helpers.maximum(tmp68, tmp64)
    tl.store(out_ptr0 + (x4), tmp69, xmask)
  390.  
  391.  
  392.  
  393.  
  394. # kernel path: /tmp/torchinductor_chilli/vk/cvkspjpsscasoaaqpr4bknw6rubi6rs377kkeeghxsldfkeoncma.py
  395. # Source Nodes: [out], Original ATen: [aten.convolution]
  396. # out => convolution_1
  397. import triton
  398. import triton.language as tl
  399. from triton.compiler.compiler import AttrsDescriptor
  400.  
  401. from torch._inductor.runtime import triton_helpers, triton_heuristics
  402. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  403. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  404.  
  405. @triton.jit
  406. def triton_poi_fused_convolution_7(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  407. ynumel = 4096
  408. xnumel = 9
  409. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  410. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  411. ymask = yindex < ynumel
  412. xoffset = tl.program_id(0) * XBLOCK
  413. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  414. xmask = xindex < xnumel
  415. x2 = xindex
  416. y3 = yindex
  417. y0 = yindex % 64
  418. y1 = (yindex // 64)
  419. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  420. tl.store(out_ptr0 + (y0 + (64*x2) + (576*y1)), tmp0, xmask)
  421.  
  422.  
  423.  
  424.  
  425. # kernel path: /tmp/torchinductor_chilli/od/codvnm4apnzie5266x5p4vcefr4uplcosymmy6te7w4m4disvhsf.py
  426. # Source Nodes: [out_1], Original ATen: [aten._native_batch_norm_legit_functional]
  427. # out_1 => add_7, add_8, mul_10, mul_11, mul_12, mul_8, mul_9, var_mean_1
  428. import triton
  429. import triton.language as tl
  430. from triton.compiler.compiler import AttrsDescriptor
  431.  
  432. from torch._inductor.runtime import triton_helpers, triton_heuristics
  433. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  434. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  435.  
@triton.jit
def triton_red_fused__native_batch_norm_legit_functional_8(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # var_mean for out_1 over 64 channels x 3249 elements each (input read at
    # x0 + 64*r1, i.e. channels innermost), accumulated with a looped
    # Welford reduction in RBLOCK-sized chunks. Writes
    #   out_ptr0 = mean, out_ptr1 = M2 (sum of squared deviations)
    #   out_ptr3 = 0.1 * mean + 0.9 * in_ptr1
    #   out_ptr5 = 0.1 * (M2 / 3249) * (3249 / 3248) + 0.9 * in_ptr2
    # i.e. a momentum-0.1 update of what look like running mean/var buffers,
    # using the unbiased variance.
    xnumel = 64
    rnumel = 3249
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Per-lane running Welford state.
    tmp2_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp2_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp2_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (64*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        # The last argument tells welford_reduce this is the first chunk.
        tmp2_mean_next, tmp2_m2_next, tmp2_weight_next = triton_helpers.welford_reduce(
            tmp1, tmp2_mean, tmp2_m2, tmp2_weight, roffset == 0
        )
        # Only update lanes holding real elements; masked lanes keep their state.
        tmp2_mean = tl.where(rmask & xmask, tmp2_mean_next, tmp2_mean)
        tmp2_m2 = tl.where(rmask & xmask, tmp2_m2_next, tmp2_m2)
        tmp2_weight = tl.where(rmask & xmask, tmp2_weight_next, tmp2_weight)
    # Merge the per-lane states along the reduction axis.
    tmp2_tmp, tmp3_tmp, tmp4_tmp = triton_helpers.welford(
        tmp2_mean, tmp2_m2, tmp2_weight, 1
    )
    tmp2 = tmp2_tmp[:, None]
    tmp3 = tmp3_tmp[:, None]
    # The merged weight (tmp4) is not used further.
    tmp4 = tmp4_tmp[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
    tl.store(out_ptr1 + (x0), tmp3, xmask)
    tmp7 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
    tmp16 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    tmp5 = 0.1
    tmp6 = tmp2 * tmp5
    tmp8 = 0.9
    tmp9 = tmp7 * tmp8
    tmp10 = tmp6 + tmp9
    # Biased variance M2 / N, then Bessel correction N / (N - 1) = 3249 / 3248.
    tmp11 = 3249.0
    tmp12 = tmp3 / tmp11
    tmp13 = 1.000307881773399
    tmp14 = tmp12 * tmp13
    tmp15 = tmp14 * tmp5
    tmp17 = tmp16 * tmp8
    tmp18 = tmp15 + tmp17
    tl.store(out_ptr3 + (x0), tmp10, xmask)
    tl.store(out_ptr5 + (x0), tmp18, xmask)
  484.  
  485.  
  486.  
  487.  
  488. # kernel path: /tmp/torchinductor_chilli/lo/cloanpjqs4btyhgyzyxhsxwzurvxpo2wa465u65kfdpu6mw4chyr.py
  489. # Source Nodes: [out_1, out_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  490. # out_1 => add_6, add_9, mul_13, mul_7, rsqrt_1, sub_1, var_mean_1
  491. # out_2 => relu_1
  492. import triton
  493. import triton.language as tl
  494. from triton.compiler.compiler import AttrsDescriptor
  495.  
  496. from torch._inductor.runtime import triton_helpers, triton_heuristics
  497. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  498. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  499.  
  500. @triton.jit
  501. def triton_poi_fused__native_batch_norm_legit_functional_relu_9(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
  502. xnumel = 207936
  503. xoffset = tl.program_id(0) * XBLOCK
  504. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  505. xmask = xindex < xnumel
  506. x2 = xindex
  507. x0 = xindex % 64
  508. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  509. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  510. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  511. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  512. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  513. tmp2 = tmp0 - tmp1
  514. tmp4 = 3249.0
  515. tmp5 = tmp3 / tmp4
  516. tmp6 = 1e-05
  517. tmp7 = tmp5 + tmp6
  518. tmp8 = libdevice.rsqrt(tmp7)
  519. tmp9 = tmp2 * tmp8
  520. tmp11 = tmp9 * tmp10
  521. tmp13 = tmp11 + tmp12
  522. tmp14 = triton_helpers.maximum(0, tmp13)
  523. tl.store(out_ptr0 + (x2), tmp14, xmask)
  524.  
  525.  
  526.  
  527.  
  528. # kernel path: /tmp/torchinductor_chilli/vt/cvt4hpbqmxmk4aanytwt546bgwqi4hivd5oal6ks3lpmczefgpzc.py
  529. # Source Nodes: [out_4, out_5, out_6], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  530. # out_4 => add_11, add_14, mul_14, mul_20, rsqrt_2, sub_2, var_mean_2
  531. # out_5 => add_15
  532. # out_6 => relu_2
  533. import triton
  534. import triton.language as tl
  535. from triton.compiler.compiler import AttrsDescriptor
  536.  
  537. from torch._inductor.runtime import triton_helpers, triton_heuristics
  538. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  539. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  540.  
  541. @triton.jit
  542. def triton_poi_fused__native_batch_norm_legit_functional_add_relu_10(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, XBLOCK : tl.constexpr):
  543. xnumel = 207936
  544. xoffset = tl.program_id(0) * XBLOCK
  545. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  546. xmask = xindex < xnumel
  547. x2 = xindex
  548. x0 = xindex % 64
  549. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  550. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  551. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  552. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  553. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  554. tmp14 = tl.load(in_out_ptr0 + (x2), xmask)
  555. tmp2 = tmp0 - tmp1
  556. tmp4 = 3249.0
  557. tmp5 = tmp3 / tmp4
  558. tmp6 = 1e-05
  559. tmp7 = tmp5 + tmp6
  560. tmp8 = libdevice.rsqrt(tmp7)
  561. tmp9 = tmp2 * tmp8
  562. tmp11 = tmp9 * tmp10
  563. tmp13 = tmp11 + tmp12
  564. tmp15 = tmp13 + tmp14
  565. tmp16 = triton_helpers.maximum(0, tmp15)
  566. tl.store(in_out_ptr0 + (x2), tmp16, xmask)
  567.  
  568.  
  569.  
  570.  
  571. # kernel path: /tmp/torchinductor_chilli/2x/c2xz42gsxbn7afzhg3pitrfsomlnmlcqformp7rf27frh7ow6fjy.py
  572. # Source Nodes: [out_14], Original ATen: [aten.convolution]
  573. # out_14 => convolution_5
  574. import triton
  575. import triton.language as tl
  576. from triton.compiler.compiler import AttrsDescriptor
  577.  
  578. from torch._inductor.runtime import triton_helpers, triton_heuristics
  579. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  580. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  581.  
  582. @triton.jit
  583. def triton_poi_fused_convolution_11(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  584. ynumel = 8192
  585. xnumel = 9
  586. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  587. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  588. ymask = yindex < ynumel
  589. xoffset = tl.program_id(0) * XBLOCK
  590. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  591. xmask = xindex < xnumel
  592. x2 = xindex
  593. y3 = yindex
  594. y0 = yindex % 64
  595. y1 = (yindex // 64)
  596. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  597. tl.store(out_ptr0 + (y0 + (64*x2) + (576*y1)), tmp0, xmask)
  598.  
  599.  
  600.  
  601.  
  602. # kernel path: /tmp/torchinductor_chilli/yx/cyxk35gzq34tsaipo6wzfghuopwljvfsgir35jcovgkosngo2ity.py
  603. # Source Nodes: [out_15], Original ATen: [aten._native_batch_norm_legit_functional]
  604. # out_15 => add_29, add_30, mul_36, mul_37, mul_38, mul_39, mul_40, var_mean_5
  605. import triton
  606. import triton.language as tl
  607. from triton.compiler.compiler import AttrsDescriptor
  608.  
  609. from torch._inductor.runtime import triton_helpers, triton_heuristics
  610. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  611. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  612.  
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_12(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel):
    # var_mean for out_15: one program per channel (XBLOCK = 1, 128 channels),
    # each reducing 841 elements read at x0 + 128*r1 (channels innermost) in
    # a single RBLOCK of 1024. It uses two passes: mean = sum / 841, then
    # M2 = sum((v - mean)^2). Writes
    #   out_ptr0 = mean, out_ptr1 = M2
    #   out_ptr3 = 0.1 * mean + 0.9 * in_ptr1
    #   out_ptr5 = 0.1 * (M2 / 841) * (841 / 840) + 0.9 * in_ptr2
    # i.e. a momentum-0.1 update of what look like running mean/var buffers,
    # using the unbiased variance.
    xnumel = 128
    XBLOCK: tl.constexpr = 1
    rnumel = 841
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*r1)), rmask & xmask, other=0.0)
    tmp19 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
    tmp28 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    tmp1 = tl.broadcast_to(tmp0, [RBLOCK])
    # tmp3 is computed but not used afterwards (generated dead code).
    tmp3 = tl.where(rmask & xmask, tmp1, 0)
    tmp4 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp6 = tl.where(rmask & xmask, tmp4, 0)
    # Pass 1: mean.
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp8 = tl.full([1], 841, tl.int32)
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp7 / tmp9
    # Pass 2: sum of squared deviations; padded lanes are masked to 0.
    tmp11 = tmp1 - tmp10
    tmp12 = tmp11 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp15, 0))
    tmp17 = 0.1
    tmp18 = tmp10 * tmp17
    tmp20 = 0.9
    tmp21 = tmp19 * tmp20
    tmp22 = tmp18 + tmp21
    # Biased variance M2 / N, then Bessel correction N / (N - 1) = 841 / 840.
    tmp23 = 841.0
    tmp24 = tmp16 / tmp23
    tmp25 = 1.0011904761904762
    tmp26 = tmp24 * tmp25
    tmp27 = tmp26 * tmp17
    tmp29 = tmp28 * tmp20
    tmp30 = tmp27 + tmp29
    tl.store(out_ptr3 + (x0), tmp22, xmask)
    tl.store(out_ptr5 + (x0), tmp30, xmask)
    tl.store(out_ptr0 + (x0), tmp10, xmask)
    tl.store(out_ptr1 + (x0), tmp16, xmask)
  659.  
  660.  
  661.  
  662.  
  663. # kernel path: /tmp/torchinductor_chilli/vv/cvvk47b5xgybm23ez23a7raivkf2sv2kr4orbokbjadnck4nft3z.py
  664. # Source Nodes: [out_15, out_16], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  665. # out_15 => add_28, add_31, mul_35, mul_41, rsqrt_5, sub_5, var_mean_5
  666. # out_16 => relu_5
  667. import triton
  668. import triton.language as tl
  669. from triton.compiler.compiler import AttrsDescriptor
  670.  
  671. from torch._inductor.runtime import triton_helpers, triton_heuristics
  672. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  673. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  674.  
  675. @triton.jit
  676. def triton_poi_fused__native_batch_norm_legit_functional_relu_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
  677. xnumel = 107648
  678. xoffset = tl.program_id(0) * XBLOCK
  679. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  680. xmask = xindex < xnumel
  681. x2 = xindex
  682. x0 = xindex % 128
  683. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  684. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  685. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  686. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  687. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  688. tmp2 = tmp0 - tmp1
  689. tmp4 = 841.0
  690. tmp5 = tmp3 / tmp4
  691. tmp6 = 1e-05
  692. tmp7 = tmp5 + tmp6
  693. tmp8 = libdevice.rsqrt(tmp7)
  694. tmp9 = tmp2 * tmp8
  695. tmp11 = tmp9 * tmp10
  696. tmp13 = tmp11 + tmp12
  697. tmp14 = triton_helpers.maximum(0, tmp13)
  698. tl.store(out_ptr0 + (x2), tmp14, xmask)
  699.  
  700.  
  701.  
  702.  
  703. # kernel path: /tmp/torchinductor_chilli/a2/ca2vg66dc6tbazekmxs7fptl72wa3nskx3lnbicxijilno3uoule.py
  704. # Source Nodes: [out_15, out_16, out_17], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  705. # out_15 => add_28, add_31, mul_35, mul_41, rsqrt_5, sub_5, var_mean_5
  706. # out_16 => relu_5
  707. # out_17 => convolution_6
  708. import triton
  709. import triton.language as tl
  710. from triton.compiler.compiler import AttrsDescriptor
  711.  
  712. from torch._inductor.runtime import triton_helpers, triton_heuristics
  713. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  714. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  715.  
  716. @triton.jit
  717. def triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  718. ynumel = 16384
  719. xnumel = 9
  720. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  721. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  722. ymask = yindex < ynumel
  723. xoffset = tl.program_id(0) * XBLOCK
  724. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  725. xmask = xindex < xnumel
  726. x2 = xindex
  727. y3 = yindex
  728. y0 = yindex % 128
  729. y1 = (yindex // 128)
  730. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  731. tl.store(out_ptr0 + (y0 + (128*x2) + (1152*y1)), tmp0, xmask)
  732.  
  733.  
  734.  
  735.  
  736. # kernel path: /tmp/torchinductor_chilli/6t/c6t2wfxi5lmlky6np3qaojn7zwyvd6xsx545pckbcrchazw4yg5z.py
  737. # Source Nodes: [identity, out_18, out_19, out_20], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  738. # identity => add_38, add_41, mul_49, mul_55, rsqrt_7, sub_7, var_mean_7
  739. # out_18 => add_33, add_36, mul_42, mul_48, rsqrt_6, sub_6, var_mean_6
  740. # out_19 => add_42
  741. # out_20 => relu_6
  742. import triton
  743. import triton.language as tl
  744. from triton.compiler.compiler import AttrsDescriptor
  745.  
  746. from torch._inductor.runtime import triton_helpers, triton_heuristics
  747. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  748. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  749.  
  750. @triton.jit
  751. def triton_poi_fused__native_batch_norm_legit_functional_add_relu_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, xnumel, XBLOCK : tl.constexpr):
  752. xnumel = 107648
  753. xoffset = tl.program_id(0) * XBLOCK
  754. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  755. xmask = xindex < xnumel
  756. x2 = xindex
  757. x0 = xindex % 128
  758. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  759. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  760. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  761. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  762. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  763. tmp14 = tl.load(in_ptr5 + (x2), xmask)
  764. tmp15 = tl.load(in_ptr6 + (x0), xmask, eviction_policy='evict_last')
  765. tmp17 = tl.load(in_ptr7 + (x0), xmask, eviction_policy='evict_last')
  766. tmp22 = tl.load(in_ptr8 + (x0), xmask, eviction_policy='evict_last')
  767. tmp24 = tl.load(in_ptr9 + (x0), xmask, eviction_policy='evict_last')
  768. tmp2 = tmp0 - tmp1
  769. tmp4 = 841.0
  770. tmp5 = tmp3 / tmp4
  771. tmp6 = 1e-05
  772. tmp7 = tmp5 + tmp6
  773. tmp8 = libdevice.rsqrt(tmp7)
  774. tmp9 = tmp2 * tmp8
  775. tmp11 = tmp9 * tmp10
  776. tmp13 = tmp11 + tmp12
  777. tmp16 = tmp14 - tmp15
  778. tmp18 = tmp17 / tmp4
  779. tmp19 = tmp18 + tmp6
  780. tmp20 = libdevice.rsqrt(tmp19)
  781. tmp21 = tmp16 * tmp20
  782. tmp23 = tmp21 * tmp22
  783. tmp25 = tmp23 + tmp24
  784. tmp26 = tmp13 + tmp25
  785. tmp27 = triton_helpers.maximum(0, tmp26)
  786. tl.store(in_out_ptr0 + (x2), tmp27, xmask)
  787.  
  788.  
  789.  
  790.  
  791. # kernel path: /tmp/torchinductor_chilli/26/c26rt5bj4lhziddwgxx2yplkvhjuuc7bht32mqdzdeewrv6zwvso.py
  792. # Source Nodes: [out_25, out_26, out_27], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  793. # out_25 => add_49, add_52, mul_63, mul_69, rsqrt_9, sub_9, var_mean_9
  794. # out_26 => add_53
  795. # out_27 => relu_8
  796. import triton
  797. import triton.language as tl
  798. from triton.compiler.compiler import AttrsDescriptor
  799.  
  800. from torch._inductor.runtime import triton_helpers, triton_heuristics
  801. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  802. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  803.  
  804. @triton.jit
  805. def triton_poi_fused__native_batch_norm_legit_functional_add_relu_16(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, XBLOCK : tl.constexpr):
  806. xnumel = 107648
  807. xoffset = tl.program_id(0) * XBLOCK
  808. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  809. xmask = xindex < xnumel
  810. x2 = xindex
  811. x0 = xindex % 128
  812. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  813. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  814. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  815. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  816. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  817. tmp14 = tl.load(in_out_ptr0 + (x2), xmask)
  818. tmp2 = tmp0 - tmp1
  819. tmp4 = 841.0
  820. tmp5 = tmp3 / tmp4
  821. tmp6 = 1e-05
  822. tmp7 = tmp5 + tmp6
  823. tmp8 = libdevice.rsqrt(tmp7)
  824. tmp9 = tmp2 * tmp8
  825. tmp11 = tmp9 * tmp10
  826. tmp13 = tmp11 + tmp12
  827. tmp15 = tmp13 + tmp14
  828. tmp16 = triton_helpers.maximum(0, tmp15)
  829. tl.store(in_out_ptr0 + (x2), tmp16, xmask)
  830.  
  831.  
  832.  
  833.  
  834. # kernel path: /tmp/torchinductor_chilli/z2/cz2hmaiqkwsixgjuxaqmaywuxpi2lcassripzkjabnb4j44mcjed.py
  835. # Source Nodes: [out_28], Original ATen: [aten.convolution]
  836. # out_28 => convolution_10
  837. import triton
  838. import triton.language as tl
  839. from triton.compiler.compiler import AttrsDescriptor
  840.  
  841. from torch._inductor.runtime import triton_helpers, triton_heuristics
  842. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  843. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  844.  
  845. @triton.jit
  846. def triton_poi_fused_convolution_17(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  847. ynumel = 32768
  848. xnumel = 9
  849. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  850. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  851. ymask = yindex < ynumel
  852. xoffset = tl.program_id(0) * XBLOCK
  853. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  854. xmask = xindex < xnumel
  855. x2 = xindex
  856. y3 = yindex
  857. y0 = yindex % 128
  858. y1 = (yindex // 128)
  859. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  860. tl.store(out_ptr0 + (y0 + (128*x2) + (1152*y1)), tmp0, xmask)
  861.  
  862.  
  863.  
  864.  
  865. # kernel path: /tmp/torchinductor_chilli/5m/c5muflxhnwscn5dsarqbe5g35q5b3e7tjk5aamn42344ztmgffpc.py
  866. # Source Nodes: [out_29], Original ATen: [aten._native_batch_norm_legit_functional]
  867. # out_29 => add_56, add_57, mul_71, mul_72, mul_73, mul_74, mul_75, var_mean_10
  868. import triton
  869. import triton.language as tl
  870. from triton.compiler.compiler import AttrsDescriptor
  871.  
  872. from torch._inductor.runtime import triton_helpers, triton_heuristics
  873. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  874. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  875.  
  876. @triton.jit
  877. def triton_per_fused__native_batch_norm_legit_functional_18(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
  878. xnumel = 256
  879. rnumel = 225
  880. RBLOCK: tl.constexpr = 256
  881. xoffset = tl.program_id(0) * XBLOCK
  882. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  883. xmask = xindex < xnumel
  884. rindex = tl.arange(0, RBLOCK)[None, :]
  885. roffset = 0
  886. rmask = rindex < rnumel
  887. r1 = rindex
  888. x0 = xindex
  889. tmp0 = tl.load(in_ptr0 + (x0 + (256*r1)), rmask & xmask, other=0.0)
  890. tmp19 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  891. tmp28 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  892. tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
  893. tmp3 = tl.where(rmask & xmask, tmp1, 0)
  894. tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
  895. tmp6 = tl.where(rmask & xmask, tmp4, 0)
  896. tmp7 = tl.sum(tmp6, 1)[:, None]
  897. tmp8 = tl.full([XBLOCK, 1], 225, tl.int32)
  898. tmp9 = tmp8.to(tl.float32)
  899. tmp10 = tmp7 / tmp9
  900. tmp11 = tmp1 - tmp10
  901. tmp12 = tmp11 * tmp11
  902. tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
  903. tmp15 = tl.where(rmask & xmask, tmp13, 0)
  904. tmp16 = tl.sum(tmp15, 1)[:, None]
  905. tmp17 = 0.1
  906. tmp18 = tmp10 * tmp17
  907. tmp20 = 0.9
  908. tmp21 = tmp19 * tmp20
  909. tmp22 = tmp18 + tmp21
  910. tmp23 = 225.0
  911. tmp24 = tmp16 / tmp23
  912. tmp25 = 1.0044642857142858
  913. tmp26 = tmp24 * tmp25
  914. tmp27 = tmp26 * tmp17
  915. tmp29 = tmp28 * tmp20
  916. tmp30 = tmp27 + tmp29
  917. tl.store(out_ptr3 + (x0), tmp22, xmask)
  918. tl.store(out_ptr5 + (x0), tmp30, xmask)
  919. tl.store(out_ptr0 + (x0), tmp10, xmask)
  920. tl.store(out_ptr1 + (x0), tmp16, xmask)
  921.  
  922.  
  923.  
  924.  
  925. # kernel path: /tmp/torchinductor_chilli/o6/co6ycddj5sgoncwk3rdnpd65b5kx2eufjx4obfxbpg7ev3hii5ti.py
  926. # Source Nodes: [out_29, out_30], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  927. # out_29 => add_55, add_58, mul_70, mul_76, rsqrt_10, sub_10, var_mean_10
  928. # out_30 => relu_9
  929. import triton
  930. import triton.language as tl
  931. from triton.compiler.compiler import AttrsDescriptor
  932.  
  933. from torch._inductor.runtime import triton_helpers, triton_heuristics
  934. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  935. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  936.  
  937. @triton.jit
  938. def triton_poi_fused__native_batch_norm_legit_functional_relu_19(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
  939. xnumel = 57600
  940. xoffset = tl.program_id(0) * XBLOCK
  941. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  942. xmask = xindex < xnumel
  943. x2 = xindex
  944. x0 = xindex % 256
  945. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  946. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  947. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  948. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  949. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  950. tmp2 = tmp0 - tmp1
  951. tmp4 = 225.0
  952. tmp5 = tmp3 / tmp4
  953. tmp6 = 1e-05
  954. tmp7 = tmp5 + tmp6
  955. tmp8 = libdevice.rsqrt(tmp7)
  956. tmp9 = tmp2 * tmp8
  957. tmp11 = tmp9 * tmp10
  958. tmp13 = tmp11 + tmp12
  959. tmp14 = triton_helpers.maximum(0, tmp13)
  960. tl.store(out_ptr0 + (x2), tmp14, xmask)
  961.  
  962.  
  963.  
  964.  
  965. # kernel path: /tmp/torchinductor_chilli/nf/cnfgqtfbw5smkkagwken4rsi6rzry342age3cxo5oozd2ewyyhez.py
  966. # Source Nodes: [out_29, out_30, out_31], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  967. # out_29 => add_55, add_58, mul_70, mul_76, rsqrt_10, sub_10, var_mean_10
  968. # out_30 => relu_9
  969. # out_31 => convolution_11
  970. import triton
  971. import triton.language as tl
  972. from triton.compiler.compiler import AttrsDescriptor
  973.  
  974. from torch._inductor.runtime import triton_helpers, triton_heuristics
  975. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  976. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  977.  
  978. @triton.jit
  979. def triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  980. ynumel = 65536
  981. xnumel = 9
  982. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  983. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  984. ymask = yindex < ynumel
  985. xoffset = tl.program_id(0) * XBLOCK
  986. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  987. xmask = xindex < xnumel
  988. x2 = xindex
  989. y3 = yindex
  990. y0 = yindex % 256
  991. y1 = (yindex // 256)
  992. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  993. tl.store(out_ptr0 + (y0 + (256*x2) + (2304*y1)), tmp0, xmask)
  994.  
  995.  
  996.  
  997.  
  998. # kernel path: /tmp/torchinductor_chilli/ti/ctizppkxqrjhhsm7luucj6whivyl2vv4fjjdpas7rporotca2pry.py
  999. # Source Nodes: [identity_1, out_32, out_33, out_34], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1000. # identity_1 => add_65, add_68, mul_84, mul_90, rsqrt_12, sub_12, var_mean_12
  1001. # out_32 => add_60, add_63, mul_77, mul_83, rsqrt_11, sub_11, var_mean_11
  1002. # out_33 => add_69
  1003. # out_34 => relu_10
  1004. import triton
  1005. import triton.language as tl
  1006. from triton.compiler.compiler import AttrsDescriptor
  1007.  
  1008. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1009. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1010. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1011.  
  1012. @triton.jit
  1013. def triton_poi_fused__native_batch_norm_legit_functional_add_relu_21(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, xnumel, XBLOCK : tl.constexpr):
  1014. xnumel = 57600
  1015. xoffset = tl.program_id(0) * XBLOCK
  1016. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  1017. xmask = xindex < xnumel
  1018. x2 = xindex
  1019. x0 = xindex % 256
  1020. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  1021. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  1022. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  1023. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  1024. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  1025. tmp14 = tl.load(in_ptr5 + (x2), xmask)
  1026. tmp15 = tl.load(in_ptr6 + (x0), xmask, eviction_policy='evict_last')
  1027. tmp17 = tl.load(in_ptr7 + (x0), xmask, eviction_policy='evict_last')
  1028. tmp22 = tl.load(in_ptr8 + (x0), xmask, eviction_policy='evict_last')
  1029. tmp24 = tl.load(in_ptr9 + (x0), xmask, eviction_policy='evict_last')
  1030. tmp2 = tmp0 - tmp1
  1031. tmp4 = 225.0
  1032. tmp5 = tmp3 / tmp4
  1033. tmp6 = 1e-05
  1034. tmp7 = tmp5 + tmp6
  1035. tmp8 = libdevice.rsqrt(tmp7)
  1036. tmp9 = tmp2 * tmp8
  1037. tmp11 = tmp9 * tmp10
  1038. tmp13 = tmp11 + tmp12
  1039. tmp16 = tmp14 - tmp15
  1040. tmp18 = tmp17 / tmp4
  1041. tmp19 = tmp18 + tmp6
  1042. tmp20 = libdevice.rsqrt(tmp19)
  1043. tmp21 = tmp16 * tmp20
  1044. tmp23 = tmp21 * tmp22
  1045. tmp25 = tmp23 + tmp24
  1046. tmp26 = tmp13 + tmp25
  1047. tmp27 = triton_helpers.maximum(0, tmp26)
  1048. tl.store(in_out_ptr0 + (x2), tmp27, xmask)
  1049.  
  1050.  
  1051.  
  1052.  
  1053. # kernel path: /tmp/torchinductor_chilli/h4/ch4hlqd3tqqqn4auweqr3lccybiiv5c6i2hmtylkhw3i3ac6nvem.py
  1054. # Source Nodes: [out_39, out_40, out_41], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1055. # out_39 => add_76, add_79, mul_104, mul_98, rsqrt_14, sub_14, var_mean_14
  1056. # out_40 => add_80
  1057. # out_41 => relu_12
  1058. import triton
  1059. import triton.language as tl
  1060. from triton.compiler.compiler import AttrsDescriptor
  1061.  
  1062. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1063. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1064. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1065.  
  1066. @triton.jit
  1067. def triton_poi_fused__native_batch_norm_legit_functional_add_relu_22(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, XBLOCK : tl.constexpr):
  1068. xnumel = 57600
  1069. xoffset = tl.program_id(0) * XBLOCK
  1070. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  1071. xmask = xindex < xnumel
  1072. x2 = xindex
  1073. x0 = xindex % 256
  1074. tmp0 = tl.load(in_ptr0 + (x2), xmask)
  1075. tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  1076. tmp3 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  1077. tmp10 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
  1078. tmp12 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  1079. tmp14 = tl.load(in_out_ptr0 + (x2), xmask)
  1080. tmp2 = tmp0 - tmp1
  1081. tmp4 = 225.0
  1082. tmp5 = tmp3 / tmp4
  1083. tmp6 = 1e-05
  1084. tmp7 = tmp5 + tmp6
  1085. tmp8 = libdevice.rsqrt(tmp7)
  1086. tmp9 = tmp2 * tmp8
  1087. tmp11 = tmp9 * tmp10
  1088. tmp13 = tmp11 + tmp12
  1089. tmp15 = tmp13 + tmp14
  1090. tmp16 = triton_helpers.maximum(0, tmp15)
  1091. tl.store(in_out_ptr0 + (x2), tmp16, xmask)
  1092.  
  1093.  
  1094.  
  1095.  
  1096. # kernel path: /tmp/torchinductor_chilli/aw/cawomzmzglbre7ilhcarrqi7hjfndlnawqmuj6pdrjwpbvtwu67i.py
  1097. # Source Nodes: [out_42], Original ATen: [aten.convolution]
  1098. # out_42 => convolution_15
  1099. import triton
  1100. import triton.language as tl
  1101. from triton.compiler.compiler import AttrsDescriptor
  1102.  
  1103. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1104. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1105. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1106.  
  1107. @triton.jit
  1108. def triton_poi_fused_convolution_23(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  1109. ynumel = 131072
  1110. xnumel = 9
  1111. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  1112. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  1113. ymask = yindex < ynumel
  1114. xoffset = tl.program_id(0) * XBLOCK
  1115. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  1116. xmask = xindex < xnumel
  1117. x2 = xindex
  1118. y3 = yindex
  1119. y0 = yindex % 256
  1120. y1 = (yindex // 256)
  1121. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  1122. tl.store(out_ptr0 + (y0 + (256*x2) + (2304*y1)), tmp0, xmask)
  1123.  
  1124.  
  1125.  
  1126.  
  1127. # kernel path: /tmp/torchinductor_chilli/fl/cflwvwmk3cuzfonagfvse3fly3w2z6wozgcbuc5yjwib3an7gogg.py
  1128. # Source Nodes: [out_43], Original ATen: [aten._native_batch_norm_legit_functional]
  1129. # out_43 => add_83, add_84, mul_106, mul_107, mul_108, mul_109, mul_110, var_mean_15
  1130. import triton
  1131. import triton.language as tl
  1132. from triton.compiler.compiler import AttrsDescriptor
  1133.  
  1134. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1135. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1136. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1137.  
  1138. @triton.jit
  1139. def triton_per_fused__native_batch_norm_legit_functional_24(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
  1140. xnumel = 512
  1141. rnumel = 64
  1142. RBLOCK: tl.constexpr = 64
  1143. xoffset = tl.program_id(0) * XBLOCK
  1144. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  1145. xmask = xindex < xnumel
  1146. rindex = tl.arange(0, RBLOCK)[None, :]
  1147. roffset = 0
  1148. rmask = rindex < rnumel
  1149. r1 = rindex
  1150. x0 = xindex
  1151. tmp0 = tl.load(in_ptr0 + (x0 + (512*r1)), rmask & xmask, other=0.0)
  1152. tmp19 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  1153. tmp28 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  1154. tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
  1155. tmp3 = tl.where(rmask & xmask, tmp1, 0)
  1156. tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
  1157. tmp6 = tl.where(rmask & xmask, tmp4, 0)
  1158. tmp7 = tl.sum(tmp6, 1)[:, None]
  1159. tmp8 = tl.full([XBLOCK, 1], 64, tl.int32)
  1160. tmp9 = tmp8.to(tl.float32)
  1161. tmp10 = tmp7 / tmp9
  1162. tmp11 = tmp1 - tmp10
  1163. tmp12 = tmp11 * tmp11
  1164. tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
  1165. tmp15 = tl.where(rmask & xmask, tmp13, 0)
  1166. tmp16 = tl.sum(tmp15, 1)[:, None]
  1167. tmp17 = 0.1
  1168. tmp18 = tmp10 * tmp17
  1169. tmp20 = 0.9
  1170. tmp21 = tmp19 * tmp20
  1171. tmp22 = tmp18 + tmp21
  1172. tmp23 = 64.0
  1173. tmp24 = tmp16 / tmp23
  1174. tmp25 = 1.0158730158730158
  1175. tmp26 = tmp24 * tmp25
  1176. tmp27 = tmp26 * tmp17
  1177. tmp29 = tmp28 * tmp20
  1178. tmp30 = tmp27 + tmp29
  1179. tl.store(out_ptr3 + (x0), tmp22, xmask)
  1180. tl.store(out_ptr5 + (x0), tmp30, xmask)
  1181. tl.store(out_ptr0 + (x0), tmp10, xmask)
  1182. tl.store(out_ptr1 + (x0), tmp16, xmask)
  1183.  
  1184.  
  1185.  
  1186.  
  1187. # kernel path: /tmp/torchinductor_chilli/sy/csyscyubm24wopyxzglh7ys4ywxr3iuludm3tpzws4m5xt4glbd2.py
  1188. # Source Nodes: [out_43, out_44], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1189. # out_43 => add_82, add_85, mul_105, mul_111, rsqrt_15, sub_15, var_mean_15
  1190. # out_44 => relu_13
  1191. import triton
  1192. import triton.language as tl
  1193. from triton.compiler.compiler import AttrsDescriptor
  1194.  
  1195. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1196. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1197. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1198.  
  1199. @triton.jit
  1200. def triton_poi_fused__native_batch_norm_legit_functional_relu_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
  1201. xnumel = 32768
  1202. xoffset = tl.program_id(0) * XBLOCK
  1203. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  1204. xmask = xindex < xnumel
  1205. x2 = xindex
  1206. x0 = xindex % 512
  1207. tmp0 = tl.load(in_ptr0 + (x2), None)
  1208. tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
  1209. tmp3 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
  1210. tmp10 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
  1211. tmp12 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
  1212. tmp2 = tmp0 - tmp1
  1213. tmp4 = 64.0
  1214. tmp5 = tmp3 / tmp4
  1215. tmp6 = 1e-05
  1216. tmp7 = tmp5 + tmp6
  1217. tmp8 = libdevice.rsqrt(tmp7)
  1218. tmp9 = tmp2 * tmp8
  1219. tmp11 = tmp9 * tmp10
  1220. tmp13 = tmp11 + tmp12
  1221. tmp14 = triton_helpers.maximum(0, tmp13)
  1222. tl.store(out_ptr0 + (x2), tmp14, None)
  1223.  
  1224.  
  1225.  
  1226.  
  1227. # kernel path: /tmp/torchinductor_chilli/zo/czoc4yry3r2czazn4sagx7wsfwtp3hosm7nwr7fekbym23kvtnfu.py
  1228. # Source Nodes: [out_43, out_44, out_45], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1229. # out_43 => add_82, add_85, mul_105, mul_111, rsqrt_15, sub_15, var_mean_15
  1230. # out_44 => relu_13
  1231. # out_45 => convolution_16
  1232. import triton
  1233. import triton.language as tl
  1234. from triton.compiler.compiler import AttrsDescriptor
  1235.  
  1236. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1237. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1238. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1239.  
  1240. @triton.jit
  1241. def triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
  1242. ynumel = 262144
  1243. xnumel = 9
  1244. yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
  1245. yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
  1246. ymask = yindex < ynumel
  1247. xoffset = tl.program_id(0) * XBLOCK
  1248. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  1249. xmask = xindex < xnumel
  1250. x2 = xindex
  1251. y3 = yindex
  1252. y0 = yindex % 512
  1253. y1 = (yindex // 512)
  1254. tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask, eviction_policy='evict_last')
  1255. tl.store(out_ptr0 + (y0 + (512*x2) + (4608*y1)), tmp0, xmask)
  1256.  
  1257.  
  1258.  
  1259.  
  1260. # kernel path: /tmp/torchinductor_chilli/x3/cx3julbo4ui5eikjj253cy76wq3pd3umd6aj6cjvz5lg7zerlerv.py
  1261. # Source Nodes: [identity_2, out_46, out_47, out_48], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1262. # identity_2 => add_92, add_95, mul_119, mul_125, rsqrt_17, sub_17, var_mean_17
  1263. # out_46 => add_87, add_90, mul_112, mul_118, rsqrt_16, sub_16, var_mean_16
  1264. # out_47 => add_96
  1265. # out_48 => relu_14
  1266. import triton
  1267. import triton.language as tl
  1268. from triton.compiler.compiler import AttrsDescriptor
  1269.  
  1270. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1271. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1272. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1273.  
  1274. @triton.jit
  1275. def triton_poi_fused__native_batch_norm_legit_functional_add_relu_27(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, xnumel, XBLOCK : tl.constexpr):
  1276. xnumel = 32768
  1277. xoffset = tl.program_id(0) * XBLOCK
  1278. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  1279. xmask = xindex < xnumel
  1280. x2 = xindex
  1281. x0 = xindex % 512
  1282. tmp0 = tl.load(in_ptr0 + (x2), None)
  1283. tmp1 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
  1284. tmp3 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
  1285. tmp10 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
  1286. tmp12 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
  1287. tmp14 = tl.load(in_ptr5 + (x2), None)
  1288. tmp15 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
  1289. tmp17 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last')
  1290. tmp22 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last')
  1291. tmp24 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last')
  1292. tmp2 = tmp0 - tmp1
  1293. tmp4 = 64.0
  1294. tmp5 = tmp3 / tmp4
  1295. tmp6 = 1e-05
  1296. tmp7 = tmp5 + tmp6
  1297. tmp8 = libdevice.rsqrt(tmp7)
  1298. tmp9 = tmp2 * tmp8
  1299. tmp11 = tmp9 * tmp10
  1300. tmp13 = tmp11 + tmp12
  1301. tmp16 = tmp14 - tmp15
  1302. tmp18 = tmp17 / tmp4
  1303. tmp19 = tmp18 + tmp6
  1304. tmp20 = libdevice.rsqrt(tmp19)
  1305. tmp21 = tmp16 * tmp20
  1306. tmp23 = tmp21 * tmp22
  1307. tmp25 = tmp23 + tmp24
  1308. tmp26 = tmp13 + tmp25
  1309. tmp27 = triton_helpers.maximum(0, tmp26)
  1310. tl.store(in_out_ptr0 + (x2), tmp27, None)
  1311.  
  1312.  
  1313.  
  1314.  
  1315. # kernel path: /tmp/torchinductor_chilli/mr/cmr7ut36zrwmn4s73ifiyzecupyoozfkyvscgohgilbdwcv6vdgv.py
  1316. # Source Nodes: [out_53, out_54, out_55, x_4], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.mean, aten.relu]
  1317. # out_53 => add_103, add_104, add_105, add_106, mul_133, mul_134, mul_135, mul_136, mul_137, mul_138, mul_139, rsqrt_19, sub_19, var_mean_19
  1318. # out_54 => add_107
  1319. # out_55 => relu_16
  1320. # x_4 => mean
  1321. import triton
  1322. import triton.language as tl
  1323. from triton.compiler.compiler import AttrsDescriptor
  1324.  
  1325. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1326. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1327. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1328.  
  1329. @triton.jit
  1330. def triton_per_fused__native_batch_norm_legit_functional_add_mean_relu_28(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
  1331. xnumel = 512
  1332. rnumel = 64
  1333. RBLOCK: tl.constexpr = 64
  1334. xoffset = tl.program_id(0) * XBLOCK
  1335. xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
  1336. xmask = xindex < xnumel
  1337. rindex = tl.arange(0, RBLOCK)[None, :]
  1338. roffset = 0
  1339. rmask = rindex < rnumel
  1340. r1 = rindex
  1341. x0 = xindex
  1342. tmp0 = tl.load(in_ptr0 + (x0 + (512*r1)), rmask & xmask, other=0.0)
  1343. tmp24 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
  1344. tmp26 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
  1345. tmp28 = tl.load(in_ptr3 + (x0 + (512*r1)), rmask & xmask, other=0.0)
  1346. tmp38 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
  1347. tmp45 = tl.load(in_ptr5 + (x0), xmask, eviction_policy='evict_last')
  1348. tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
  1349. tmp3 = tl.where(rmask & xmask, tmp1, 0)
  1350. tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
  1351. tmp6 = tl.where(rmask & xmask, tmp4, 0)
  1352. tmp7 = tl.sum(tmp6, 1)[:, None]
  1353. tmp8 = tl.full([XBLOCK, 1], 64, tl.int32)
  1354. tmp9 = tmp8.to(tl.float32)
  1355. tmp10 = tmp7 / tmp9
  1356. tmp11 = tmp1 - tmp10
  1357. tmp12 = tmp11 * tmp11
  1358. tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
  1359. tmp15 = tl.where(rmask & xmask, tmp13, 0)
  1360. tmp16 = tl.sum(tmp15, 1)[:, None]
  1361. tmp17 = tmp0 - tmp10
  1362. tmp18 = 64.0
  1363. tmp19 = tmp16 / tmp18
  1364. tmp20 = 1e-05
  1365. tmp21 = tmp19 + tmp20
  1366. tmp22 = libdevice.rsqrt(tmp21)
  1367. tmp23 = tmp17 * tmp22
  1368. tmp25 = tmp23 * tmp24
  1369. tmp27 = tmp25 + tmp26
  1370. tmp29 = tmp27 + tmp28
  1371. tmp30 = triton_helpers.maximum(0, tmp29)
  1372. tmp31 = tl.broadcast_to(tmp30, [XBLOCK, RBLOCK])
  1373. tmp33 = tl.where(rmask & xmask, tmp31, 0)
  1374. tmp34 = tl.sum(tmp33, 1)[:, None]
  1375. tmp35 = tmp34 / tmp18
  1376. tmp36 = 0.1
  1377. tmp37 = tmp10 * tmp36
  1378. tmp39 = 0.9
  1379. tmp40 = tmp38 * tmp39
  1380. tmp41 = tmp37 + tmp40
  1381. tmp42 = 1.0158730158730158
  1382. tmp43 = tmp19 * tmp42
  1383. tmp44 = tmp43 * tmp36
  1384. tmp46 = tmp45 * tmp39
  1385. tmp47 = tmp44 + tmp46
  1386. tl.debug_barrier()
  1387. tl.store(in_out_ptr0 + (x0), tmp35, xmask)
  1388. tl.store(out_ptr3 + (x0), tmp41, xmask)
  1389. tl.store(out_ptr5 + (x0), tmp47, xmask)
  1390.  
  1391.  
  1392.  
  1393.  
  1394. # kernel path: /tmp/torchinductor_chilli/nj/cnj4bf4ulsnoqw46btyaeowtw6gfffjthwhbb7vulfzsoz43jkdn.py
  1395. # Source Nodes: [x_1], Original ATen: [aten.add]
  1396. # x_1 => add
  1397. import triton
  1398. import triton.language as tl
  1399. from triton.compiler.compiler import AttrsDescriptor
  1400.  
  1401. from torch._inductor.runtime import triton_helpers, triton_heuristics
  1402. from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
  1403. from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
  1404.  
  1405. @triton.jit
  1406. def triton_poi_fused_add_29(in_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
  1407. xnumel = 1
  1408. xoffset = tl.program_id(0) * XBLOCK
  1409. xindex = xoffset + tl.arange(0, XBLOCK)[:]
  1410. xmask = xindex < xnumel
  1411. tmp0 = tl.load(in_ptr0 + (0))
  1412. tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
  1413. tmp2 = tl.full([1], 1, tl.int64)
  1414. tmp3 = tmp1 + tmp2
  1415. tl.store(out_ptr1 + (tl.full([XBLOCK], 0, tl.int32)), tmp3, None)
  1416.  
  1417.  
  1418.  
  1419.  
  1420.  
  1421.  
  1422.  
  1423. def call(args):
  1424. arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1 = args
  1425. args.clear()
  1426. assert_size_stride(arg0_1, (64, 3, 7, 7), (147, 49, 7, 1))
  1427. assert_size_stride(arg1_1, (64, ), (1, ))
  1428. assert_size_stride(arg2_1, (64, ), (1, ))
  1429. assert_size_stride(arg3_1, (64, 64, 3, 3), (576, 9, 3, 1))
  1430. assert_size_stride(arg4_1, (64, ), (1, ))
  1431. assert_size_stride(arg5_1, (64, ), (1, ))
  1432. assert_size_stride(arg6_1, (64, 64, 3, 3), (576, 9, 3, 1))
  1433. assert_size_stride(arg7_1, (64, ), (1, ))
  1434. assert_size_stride(arg8_1, (64, ), (1, ))
  1435. assert_size_stride(arg9_1, (64, 64, 3, 3), (576, 9, 3, 1))
  1436. assert_size_stride(arg10_1, (64, ), (1, ))
  1437. assert_size_stride(arg11_1, (64, ), (1, ))
  1438. assert_size_stride(arg12_1, (64, 64, 3, 3), (576, 9, 3, 1))
  1439. assert_size_stride(arg13_1, (64, ), (1, ))
  1440. assert_size_stride(arg14_1, (64, ), (1, ))
  1441. assert_size_stride(arg15_1, (128, 64, 3, 3), (576, 9, 3, 1))
  1442. assert_size_stride(arg16_1, (128, ), (1, ))
  1443. assert_size_stride(arg17_1, (128, ), (1, ))
  1444. assert_size_stride(arg18_1, (128, 128, 3, 3), (1152, 9, 3, 1))
  1445. assert_size_stride(arg19_1, (128, ), (1, ))
  1446. assert_size_stride(arg20_1, (128, ), (1, ))
  1447. assert_size_stride(arg21_1, (128, 64, 1, 1), (64, 1, 1, 1))
  1448. assert_size_stride(arg22_1, (128, ), (1, ))
  1449. assert_size_stride(arg23_1, (128, ), (1, ))
  1450. assert_size_stride(arg24_1, (128, 128, 3, 3), (1152, 9, 3, 1))
  1451. assert_size_stride(arg25_1, (128, ), (1, ))
  1452. assert_size_stride(arg26_1, (128, ), (1, ))
  1453. assert_size_stride(arg27_1, (128, 128, 3, 3), (1152, 9, 3, 1))
  1454. assert_size_stride(arg28_1, (128, ), (1, ))
  1455. assert_size_stride(arg29_1, (128, ), (1, ))
  1456. assert_size_stride(arg30_1, (256, 128, 3, 3), (1152, 9, 3, 1))
  1457. assert_size_stride(arg31_1, (256, ), (1, ))
  1458. assert_size_stride(arg32_1, (256, ), (1, ))
  1459. assert_size_stride(arg33_1, (256, 256, 3, 3), (2304, 9, 3, 1))
  1460. assert_size_stride(arg34_1, (256, ), (1, ))
  1461. assert_size_stride(arg35_1, (256, ), (1, ))
  1462. assert_size_stride(arg36_1, (256, 128, 1, 1), (128, 1, 1, 1))
  1463. assert_size_stride(arg37_1, (256, ), (1, ))
  1464. assert_size_stride(arg38_1, (256, ), (1, ))
  1465. assert_size_stride(arg39_1, (256, 256, 3, 3), (2304, 9, 3, 1))
  1466. assert_size_stride(arg40_1, (256, ), (1, ))
  1467. assert_size_stride(arg41_1, (256, ), (1, ))
  1468. assert_size_stride(arg42_1, (256, 256, 3, 3), (2304, 9, 3, 1))
  1469. assert_size_stride(arg43_1, (256, ), (1, ))
  1470. assert_size_stride(arg44_1, (256, ), (1, ))
  1471. assert_size_stride(arg45_1, (512, 256, 3, 3), (2304, 9, 3, 1))
  1472. assert_size_stride(arg46_1, (512, ), (1, ))
  1473. assert_size_stride(arg47_1, (512, ), (1, ))
  1474. assert_size_stride(arg48_1, (512, 512, 3, 3), (4608, 9, 3, 1))
  1475. assert_size_stride(arg49_1, (512, ), (1, ))
  1476. assert_size_stride(arg50_1, (512, ), (1, ))
  1477. assert_size_stride(arg51_1, (512, 256, 1, 1), (256, 1, 1, 1))
  1478. assert_size_stride(arg52_1, (512, ), (1, ))
  1479. assert_size_stride(arg53_1, (512, ), (1, ))
  1480. assert_size_stride(arg54_1, (512, 512, 3, 3), (4608, 9, 3, 1))
  1481. assert_size_stride(arg55_1, (512, ), (1, ))
  1482. assert_size_stride(arg56_1, (512, ), (1, ))
  1483. assert_size_stride(arg57_1, (512, 512, 3, 3), (4608, 9, 3, 1))
  1484. assert_size_stride(arg58_1, (512, ), (1, ))
  1485. assert_size_stride(arg59_1, (512, ), (1, ))
  1486. assert_size_stride(arg60_1, (1000, 512), (512, 1))
  1487. assert_size_stride(arg61_1, (1000, ), (1, ))
  1488. assert_size_stride(arg62_1, (64, ), (1, ))
  1489. assert_size_stride(arg63_1, (64, ), (1, ))
  1490. assert_size_stride(arg64_1, (), ())
  1491. assert_size_stride(arg65_1, (64, ), (1, ))
  1492. assert_size_stride(arg66_1, (64, ), (1, ))
  1493. assert_size_stride(arg67_1, (), ())
  1494. assert_size_stride(arg68_1, (64, ), (1, ))
  1495. assert_size_stride(arg69_1, (64, ), (1, ))
  1496. assert_size_stride(arg70_1, (), ())
  1497. assert_size_stride(arg71_1, (64, ), (1, ))
  1498. assert_size_stride(arg72_1, (64, ), (1, ))
  1499. assert_size_stride(arg73_1, (), ())
  1500. assert_size_stride(arg74_1, (64, ), (1, ))
  1501. assert_size_stride(arg75_1, (64, ), (1, ))
  1502. assert_size_stride(arg76_1, (), ())
  1503. assert_size_stride(arg77_1, (128, ), (1, ))
  1504. assert_size_stride(arg78_1, (128, ), (1, ))
  1505. assert_size_stride(arg79_1, (), ())
  1506. assert_size_stride(arg80_1, (128, ), (1, ))
  1507. assert_size_stride(arg81_1, (128, ), (1, ))
  1508. assert_size_stride(arg82_1, (), ())
  1509. assert_size_stride(arg83_1, (128, ), (1, ))
  1510. assert_size_stride(arg84_1, (128, ), (1, ))
  1511. assert_size_stride(arg85_1, (), ())
  1512. assert_size_stride(arg86_1, (128, ), (1, ))
  1513. assert_size_stride(arg87_1, (128, ), (1, ))
  1514. assert_size_stride(arg88_1, (), ())
  1515. assert_size_stride(arg89_1, (128, ), (1, ))
  1516. assert_size_stride(arg90_1, (128, ), (1, ))
  1517. assert_size_stride(arg91_1, (), ())
  1518. assert_size_stride(arg92_1, (256, ), (1, ))
  1519. assert_size_stride(arg93_1, (256, ), (1, ))
  1520. assert_size_stride(arg94_1, (), ())
  1521. assert_size_stride(arg95_1, (256, ), (1, ))
  1522. assert_size_stride(arg96_1, (256, ), (1, ))
  1523. assert_size_stride(arg97_1, (), ())
  1524. assert_size_stride(arg98_1, (256, ), (1, ))
  1525. assert_size_stride(arg99_1, (256, ), (1, ))
  1526. assert_size_stride(arg100_1, (), ())
  1527. assert_size_stride(arg101_1, (256, ), (1, ))
  1528. assert_size_stride(arg102_1, (256, ), (1, ))
  1529. assert_size_stride(arg103_1, (), ())
  1530. assert_size_stride(arg104_1, (256, ), (1, ))
  1531. assert_size_stride(arg105_1, (256, ), (1, ))
  1532. assert_size_stride(arg106_1, (), ())
  1533. assert_size_stride(arg107_1, (512, ), (1, ))
  1534. assert_size_stride(arg108_1, (512, ), (1, ))
  1535. assert_size_stride(arg109_1, (), ())
  1536. assert_size_stride(arg110_1, (512, ), (1, ))
  1537. assert_size_stride(arg111_1, (512, ), (1, ))
  1538. assert_size_stride(arg112_1, (), ())
  1539. assert_size_stride(arg113_1, (512, ), (1, ))
  1540. assert_size_stride(arg114_1, (512, ), (1, ))
  1541. assert_size_stride(arg115_1, (), ())
  1542. assert_size_stride(arg116_1, (512, ), (1, ))
  1543. assert_size_stride(arg117_1, (512, ), (1, ))
  1544. assert_size_stride(arg118_1, (), ())
  1545. assert_size_stride(arg119_1, (512, ), (1, ))
  1546. assert_size_stride(arg120_1, (512, ), (1, ))
  1547. assert_size_stride(arg121_1, (), ())
  1548. assert_size_stride(arg122_1, (1, 3, 228, 228), (155952, 51984, 228, 1))
  1549. with torch.cuda._DeviceGuard(0):
  1550. torch.cuda.set_device(0)
  1551. buf0 = empty_strided_cuda((1, 3, 228, 228), (155952, 1, 684, 3), torch.float32)
  1552. # Source Nodes: [x], Original ATen: [aten.convolution]
  1553. stream0 = get_raw_stream(0)
  1554. triton_poi_fused_convolution_0[grid(3, 51984)](arg122_1, buf0, 3, 51984, XBLOCK=256, YBLOCK=4, num_warps=4, num_stages=1)
  1555. del arg122_1
  1556. buf1 = empty_strided_cuda((64, 3, 7, 7), (147, 1, 21, 3), torch.float32)
  1557. # Source Nodes: [x], Original ATen: [aten.convolution]
  1558. triton_poi_fused_convolution_1[grid(192, 49)](arg0_1, buf1, 192, 49, XBLOCK=32, YBLOCK=32, num_warps=4, num_stages=1)
  1559. del arg0_1
  1560. # Source Nodes: [x], Original ATen: [aten.convolution]
  1561. buf2 = extern_kernels.convolution(buf0, buf1, stride=(2, 2), padding=(3, 3), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1562. assert_size_stride(buf2, (1, 64, 114, 114), (831744, 1, 7296, 64))
  1563. del buf0
  1564. del buf1
  1565. buf3 = empty_strided_cuda((1, 64, 1, 1, 2, 51), (6528, 1, 6528, 6528, 3264, 64), torch.float32)
  1566. buf4 = empty_strided_cuda((1, 64, 1, 1, 2, 51), (6528, 1, 6528, 6528, 3264, 64), torch.float32)
  1567. buf5 = empty_strided_cuda((1, 64, 1, 1, 2, 51), (6528, 1, 6528, 6528, 3264, 64), torch.float32)
  1568. # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
  1569. triton_per_fused__native_batch_norm_legit_functional_2[grid(6528)](buf2, buf3, buf4, buf5, 6528, 128, XBLOCK=32, num_warps=8, num_stages=1)
  1570. buf6 = empty_strided_cuda((1, 64, 1, 1, 2), (128, 1, 128, 128, 64), torch.float32)
  1571. buf7 = empty_strided_cuda((1, 64, 1, 1, 2), (128, 1, 128, 128, 64), torch.float32)
  1572. buf8 = empty_strided_cuda((1, 64, 1, 1, 2), (128, 1, 128, 128, 64), torch.float32)
  1573. # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
  1574. triton_per_fused__native_batch_norm_legit_functional_3[grid(128)](buf3, buf4, buf5, buf6, buf7, buf8, 128, 51, XBLOCK=1, num_warps=2, num_stages=1)
  1575. del buf3
  1576. del buf4
  1577. del buf5
  1578. buf9 = empty_strided_cuda((1, 64, 1, 1), (64, 1, 1, 1), torch.float32)
  1579. buf10 = empty_strided_cuda((1, 64, 1, 1), (64, 1, 64, 64), torch.float32)
  1580. # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
  1581. triton_per_fused__native_batch_norm_legit_functional_4[grid(64)](buf6, buf7, buf8, arg62_1, arg63_1, buf9, buf10, arg62_1, arg63_1, 64, 2, XBLOCK=8, num_warps=2, num_stages=1)
  1582. del arg62_1
  1583. del arg63_1
  1584. buf12 = buf2; del buf2 # reuse
  1585. # Source Nodes: [x_1, x_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1586. triton_poi_fused__native_batch_norm_legit_functional_relu_5[grid(831744)](buf12, buf9, buf10, arg1_1, arg2_1, 831744, XBLOCK=1024, num_warps=4, num_stages=1)
  1587. del arg1_1
  1588. del arg2_1
  1589. buf13 = empty_strided_cuda((1, 64, 57, 57), (207936, 1, 3648, 64), torch.float32)
  1590. # Source Nodes: [x_1, x_2, x_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.max_pool2d_with_indices, aten.relu]
  1591. triton_poi_fused__native_batch_norm_legit_functional_max_pool2d_with_indices_relu_6[grid(207936)](buf12, buf13, 207936, XBLOCK=512, num_warps=8, num_stages=1)
  1592. del buf12
  1593. buf14 = empty_strided_cuda((64, 64, 3, 3), (576, 1, 192, 64), torch.float32)
  1594. # Source Nodes: [out], Original ATen: [aten.convolution]
  1595. triton_poi_fused_convolution_7[grid(4096, 9)](arg3_1, buf14, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1596. del arg3_1
  1597. # Source Nodes: [out], Original ATen: [aten.convolution]
  1598. buf15 = extern_kernels.convolution(buf13, buf14, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1599. assert_size_stride(buf15, (1, 64, 57, 57), (207936, 1, 3648, 64))
  1600. buf16 = buf9; del buf9 # reuse
  1601. buf17 = buf10; del buf10 # reuse
  1602. # Source Nodes: [out_1], Original ATen: [aten._native_batch_norm_legit_functional]
  1603. triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf15, arg65_1, arg66_1, buf16, buf17, arg65_1, arg66_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
  1604. del arg65_1
  1605. del arg66_1
  1606. buf19 = empty_strided_cuda((1, 64, 57, 57), (207936, 1, 3648, 64), torch.float32)
  1607. # Source Nodes: [out_1, out_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1608. triton_poi_fused__native_batch_norm_legit_functional_relu_9[grid(207936)](buf15, buf16, buf17, arg4_1, arg5_1, buf19, 207936, XBLOCK=1024, num_warps=4, num_stages=1)
  1609. del arg4_1
  1610. del arg5_1
  1611. del buf15
  1612. buf20 = buf14; del buf14 # reuse
  1613. # Source Nodes: [out_1, out_2, out_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1614. triton_poi_fused_convolution_7[grid(4096, 9)](arg6_1, buf20, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1615. del arg6_1
  1616. # Source Nodes: [out_1, out_2, out_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1617. buf21 = extern_kernels.convolution(buf19, buf20, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1618. assert_size_stride(buf21, (1, 64, 57, 57), (207936, 1, 3648, 64))
  1619. del buf19
  1620. buf22 = reinterpret_tensor(buf17, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf17 # reuse
  1621. buf23 = reinterpret_tensor(buf16, (1, 64, 1, 1), (64, 1, 64, 64), 0); del buf16 # reuse
  1622. # Source Nodes: [out_4], Original ATen: [aten._native_batch_norm_legit_functional]
  1623. triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf21, arg68_1, arg69_1, buf22, buf23, arg68_1, arg69_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
  1624. del arg68_1
  1625. del arg69_1
  1626. buf25 = buf13; del buf13 # reuse
  1627. # Source Nodes: [out_4, out_5, out_6], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1628. triton_poi_fused__native_batch_norm_legit_functional_add_relu_10[grid(207936)](buf25, buf21, buf22, buf23, arg7_1, arg8_1, 207936, XBLOCK=512, num_warps=8, num_stages=1)
  1629. del arg7_1
  1630. del arg8_1
  1631. buf26 = buf20; del buf20 # reuse
  1632. # Source Nodes: [out_7], Original ATen: [aten.convolution]
  1633. triton_poi_fused_convolution_7[grid(4096, 9)](arg9_1, buf26, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1634. del arg9_1
  1635. # Source Nodes: [out_7], Original ATen: [aten.convolution]
  1636. buf27 = extern_kernels.convolution(buf25, buf26, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1637. assert_size_stride(buf27, (1, 64, 57, 57), (207936, 1, 3648, 64))
  1638. buf28 = reinterpret_tensor(buf23, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf23 # reuse
  1639. buf29 = reinterpret_tensor(buf22, (1, 64, 1, 1), (64, 1, 64, 64), 0); del buf22 # reuse
  1640. # Source Nodes: [out_8], Original ATen: [aten._native_batch_norm_legit_functional]
  1641. triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf27, arg71_1, arg72_1, buf28, buf29, arg71_1, arg72_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
  1642. del arg71_1
  1643. del arg72_1
  1644. buf31 = buf21; del buf21 # reuse
  1645. # Source Nodes: [out_8, out_9], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1646. triton_poi_fused__native_batch_norm_legit_functional_relu_9[grid(207936)](buf27, buf28, buf29, arg10_1, arg11_1, buf31, 207936, XBLOCK=1024, num_warps=4, num_stages=1)
  1647. del arg10_1
  1648. del arg11_1
  1649. del buf27
  1650. buf32 = buf26; del buf26 # reuse
  1651. # Source Nodes: [out_10, out_8, out_9], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1652. triton_poi_fused_convolution_7[grid(4096, 9)](arg12_1, buf32, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1653. del arg12_1
  1654. # Source Nodes: [out_10, out_8, out_9], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1655. buf33 = extern_kernels.convolution(buf31, buf32, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1656. assert_size_stride(buf33, (1, 64, 57, 57), (207936, 1, 3648, 64))
  1657. del buf31
  1658. del buf32
  1659. buf34 = reinterpret_tensor(buf29, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf29 # reuse
  1660. buf35 = reinterpret_tensor(buf28, (1, 64, 1, 1), (64, 1, 64, 64), 0); del buf28 # reuse
  1661. # Source Nodes: [out_11], Original ATen: [aten._native_batch_norm_legit_functional]
  1662. triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf33, arg74_1, arg75_1, buf34, buf35, arg74_1, arg75_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
  1663. del arg74_1
  1664. del arg75_1
  1665. buf37 = buf25; del buf25 # reuse
  1666. # Source Nodes: [out_11, out_12, out_13], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1667. triton_poi_fused__native_batch_norm_legit_functional_add_relu_10[grid(207936)](buf37, buf33, buf34, buf35, arg13_1, arg14_1, 207936, XBLOCK=512, num_warps=8, num_stages=1)
  1668. del arg13_1
  1669. del arg14_1
  1670. del buf33
  1671. del buf34
  1672. del buf35
  1673. buf38 = empty_strided_cuda((128, 64, 3, 3), (576, 1, 192, 64), torch.float32)
  1674. # Source Nodes: [out_14], Original ATen: [aten.convolution]
  1675. triton_poi_fused_convolution_11[grid(8192, 9)](arg15_1, buf38, 8192, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1676. del arg15_1
  1677. # Source Nodes: [out_14], Original ATen: [aten.convolution]
  1678. buf39 = extern_kernels.convolution(buf37, buf38, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1679. assert_size_stride(buf39, (1, 128, 29, 29), (107648, 1, 3712, 128))
  1680. del buf38
  1681. buf40 = reinterpret_tensor(buf8, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf8 # reuse
  1682. buf41 = reinterpret_tensor(buf7, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf7 # reuse
  1683. # Source Nodes: [out_15], Original ATen: [aten._native_batch_norm_legit_functional]
  1684. triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf39, arg77_1, arg78_1, buf40, buf41, arg77_1, arg78_1, 128, 841, num_warps=8, num_stages=1)
  1685. del arg77_1
  1686. del arg78_1
  1687. buf43 = empty_strided_cuda((1, 128, 29, 29), (107648, 1, 3712, 128), torch.float32)
  1688. # Source Nodes: [out_15, out_16], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1689. triton_poi_fused__native_batch_norm_legit_functional_relu_13[grid(107648)](buf39, buf40, buf41, arg16_1, arg17_1, buf43, 107648, XBLOCK=512, num_warps=8, num_stages=1)
  1690. del arg16_1
  1691. del arg17_1
  1692. del buf39
  1693. buf44 = empty_strided_cuda((128, 128, 3, 3), (1152, 1, 384, 128), torch.float32)
  1694. # Source Nodes: [out_15, out_16, out_17], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1695. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14[grid(16384, 9)](arg18_1, buf44, 16384, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1696. del arg18_1
  1697. # Source Nodes: [out_15, out_16, out_17], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1698. buf45 = extern_kernels.convolution(buf43, buf44, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1699. assert_size_stride(buf45, (1, 128, 29, 29), (107648, 1, 3712, 128))
  1700. buf46 = reinterpret_tensor(buf41, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf41 # reuse
  1701. buf47 = reinterpret_tensor(buf40, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf40 # reuse
  1702. # Source Nodes: [out_18], Original ATen: [aten._native_batch_norm_legit_functional]
  1703. triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf45, arg80_1, arg81_1, buf46, buf47, arg80_1, arg81_1, 128, 841, num_warps=8, num_stages=1)
  1704. del arg80_1
  1705. del arg81_1
  1706. # Source Nodes: [getattr_l__self___layer2___0___downsample_0], Original ATen: [aten.convolution]
  1707. buf49 = extern_kernels.convolution(buf37, arg21_1, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1708. assert_size_stride(buf49, (1, 128, 29, 29), (107648, 1, 3712, 128))
  1709. del arg21_1
  1710. del buf37
  1711. buf50 = reinterpret_tensor(buf6, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf6 # reuse
  1712. buf51 = empty_strided_cuda((1, 128, 1, 1), (128, 1, 128, 128), torch.float32)
  1713. # Source Nodes: [identity], Original ATen: [aten._native_batch_norm_legit_functional]
  1714. triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf49, arg83_1, arg84_1, buf50, buf51, arg83_1, arg84_1, 128, 841, num_warps=8, num_stages=1)
  1715. del arg83_1
  1716. del arg84_1
  1717. buf53 = buf43; del buf43 # reuse
  1718. buf54 = buf53; del buf53 # reuse
  1719. # Source Nodes: [identity, out_18, out_19, out_20], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1720. triton_poi_fused__native_batch_norm_legit_functional_add_relu_15[grid(107648)](buf54, buf45, buf46, buf47, arg19_1, arg20_1, buf49, buf50, buf51, arg22_1, arg23_1, 107648, XBLOCK=512, num_warps=8, num_stages=1)
  1721. del arg19_1
  1722. del arg20_1
  1723. del arg22_1
  1724. del arg23_1
  1725. del buf45
  1726. del buf46
  1727. del buf47
  1728. buf55 = buf44; del buf44 # reuse
  1729. # Source Nodes: [out_20, out_21], Original ATen: [aten.convolution, aten.relu]
  1730. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14[grid(16384, 9)](arg24_1, buf55, 16384, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1731. del arg24_1
  1732. # Source Nodes: [out_20, out_21], Original ATen: [aten.convolution, aten.relu]
  1733. buf56 = extern_kernels.convolution(buf54, buf55, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1734. assert_size_stride(buf56, (1, 128, 29, 29), (107648, 1, 3712, 128))
  1735. buf57 = reinterpret_tensor(buf51, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf51 # reuse
  1736. buf58 = reinterpret_tensor(buf50, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf50 # reuse
  1737. # Source Nodes: [out_22], Original ATen: [aten._native_batch_norm_legit_functional]
  1738. triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf56, arg86_1, arg87_1, buf57, buf58, arg86_1, arg87_1, 128, 841, num_warps=8, num_stages=1)
  1739. del arg86_1
  1740. del arg87_1
  1741. buf60 = buf49; del buf49 # reuse
  1742. # Source Nodes: [out_22, out_23], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1743. triton_poi_fused__native_batch_norm_legit_functional_relu_13[grid(107648)](buf56, buf57, buf58, arg25_1, arg26_1, buf60, 107648, XBLOCK=512, num_warps=8, num_stages=1)
  1744. del arg25_1
  1745. del arg26_1
  1746. del buf56
  1747. buf61 = buf55; del buf55 # reuse
  1748. # Source Nodes: [out_22, out_23, out_24], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1749. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14[grid(16384, 9)](arg27_1, buf61, 16384, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1750. del arg27_1
  1751. # Source Nodes: [out_22, out_23, out_24], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1752. buf62 = extern_kernels.convolution(buf60, buf61, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1753. assert_size_stride(buf62, (1, 128, 29, 29), (107648, 1, 3712, 128))
  1754. del buf60
  1755. del buf61
  1756. buf63 = reinterpret_tensor(buf58, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf58 # reuse
  1757. buf64 = reinterpret_tensor(buf57, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf57 # reuse
  1758. # Source Nodes: [out_25], Original ATen: [aten._native_batch_norm_legit_functional]
  1759. triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf62, arg89_1, arg90_1, buf63, buf64, arg89_1, arg90_1, 128, 841, num_warps=8, num_stages=1)
  1760. del arg89_1
  1761. del arg90_1
  1762. buf66 = buf54; del buf54 # reuse
  1763. # Source Nodes: [out_25, out_26, out_27], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1764. triton_poi_fused__native_batch_norm_legit_functional_add_relu_16[grid(107648)](buf66, buf62, buf63, buf64, arg28_1, arg29_1, 107648, XBLOCK=512, num_warps=8, num_stages=1)
  1765. del arg28_1
  1766. del arg29_1
  1767. del buf62
  1768. del buf63
  1769. del buf64
  1770. buf67 = empty_strided_cuda((256, 128, 3, 3), (1152, 1, 384, 128), torch.float32)
  1771. # Source Nodes: [out_28], Original ATen: [aten.convolution]
  1772. triton_poi_fused_convolution_17[grid(32768, 9)](arg30_1, buf67, 32768, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1773. del arg30_1
  1774. # Source Nodes: [out_28], Original ATen: [aten.convolution]
  1775. buf68 = extern_kernels.convolution(buf66, buf67, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1776. assert_size_stride(buf68, (1, 256, 15, 15), (57600, 1, 3840, 256))
  1777. del buf67
  1778. buf69 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 1, 1), torch.float32)
  1779. buf70 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 256, 256), torch.float32)
  1780. # Source Nodes: [out_29], Original ATen: [aten._native_batch_norm_legit_functional]
  1781. triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf68, arg92_1, arg93_1, buf69, buf70, arg92_1, arg93_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
  1782. del arg92_1
  1783. del arg93_1
  1784. buf72 = empty_strided_cuda((1, 256, 15, 15), (57600, 1, 3840, 256), torch.float32)
  1785. # Source Nodes: [out_29, out_30], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1786. triton_poi_fused__native_batch_norm_legit_functional_relu_19[grid(57600)](buf68, buf69, buf70, arg31_1, arg32_1, buf72, 57600, XBLOCK=512, num_warps=4, num_stages=1)
  1787. del arg31_1
  1788. del arg32_1
  1789. del buf68
  1790. buf73 = empty_strided_cuda((256, 256, 3, 3), (2304, 1, 768, 256), torch.float32)
  1791. # Source Nodes: [out_29, out_30, out_31], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1792. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20[grid(65536, 9)](arg33_1, buf73, 65536, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1793. del arg33_1
  1794. # Source Nodes: [out_29, out_30, out_31], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1795. buf74 = extern_kernels.convolution(buf72, buf73, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1796. assert_size_stride(buf74, (1, 256, 15, 15), (57600, 1, 3840, 256))
  1797. buf75 = reinterpret_tensor(buf70, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf70 # reuse
  1798. buf76 = reinterpret_tensor(buf69, (1, 256, 1, 1), (256, 1, 256, 256), 0); del buf69 # reuse
  1799. # Source Nodes: [out_32], Original ATen: [aten._native_batch_norm_legit_functional]
  1800. triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf74, arg95_1, arg96_1, buf75, buf76, arg95_1, arg96_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
  1801. del arg95_1
  1802. del arg96_1
  1803. # Source Nodes: [getattr_l__self___layer3___0___downsample_0], Original ATen: [aten.convolution]
  1804. buf78 = extern_kernels.convolution(buf66, arg36_1, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1805. assert_size_stride(buf78, (1, 256, 15, 15), (57600, 1, 3840, 256))
  1806. del arg36_1
  1807. del buf66
  1808. buf79 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 1, 1), torch.float32)
  1809. buf80 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 256, 256), torch.float32)
  1810. # Source Nodes: [identity_1], Original ATen: [aten._native_batch_norm_legit_functional]
  1811. triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf78, arg98_1, arg99_1, buf79, buf80, arg98_1, arg99_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
  1812. del arg98_1
  1813. del arg99_1
  1814. buf82 = buf72; del buf72 # reuse
  1815. buf83 = buf82; del buf82 # reuse
  1816. # Source Nodes: [identity_1, out_32, out_33, out_34], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1817. triton_poi_fused__native_batch_norm_legit_functional_add_relu_21[grid(57600)](buf83, buf74, buf75, buf76, arg34_1, arg35_1, buf78, buf79, buf80, arg37_1, arg38_1, 57600, XBLOCK=256, num_warps=4, num_stages=1)
  1818. del arg34_1
  1819. del arg35_1
  1820. del arg37_1
  1821. del arg38_1
  1822. del buf74
  1823. del buf75
  1824. del buf76
  1825. buf84 = buf73; del buf73 # reuse
  1826. # Source Nodes: [out_34, out_35], Original ATen: [aten.convolution, aten.relu]
  1827. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20[grid(65536, 9)](arg39_1, buf84, 65536, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1828. del arg39_1
  1829. # Source Nodes: [out_34, out_35], Original ATen: [aten.convolution, aten.relu]
  1830. buf85 = extern_kernels.convolution(buf83, buf84, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1831. assert_size_stride(buf85, (1, 256, 15, 15), (57600, 1, 3840, 256))
  1832. buf86 = reinterpret_tensor(buf80, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf80 # reuse
  1833. buf87 = reinterpret_tensor(buf79, (1, 256, 1, 1), (256, 1, 256, 256), 0); del buf79 # reuse
  1834. # Source Nodes: [out_36], Original ATen: [aten._native_batch_norm_legit_functional]
  1835. triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf85, arg101_1, arg102_1, buf86, buf87, arg101_1, arg102_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
  1836. del arg101_1
  1837. del arg102_1
  1838. buf89 = buf78; del buf78 # reuse
  1839. # Source Nodes: [out_36, out_37], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1840. triton_poi_fused__native_batch_norm_legit_functional_relu_19[grid(57600)](buf85, buf86, buf87, arg40_1, arg41_1, buf89, 57600, XBLOCK=512, num_warps=4, num_stages=1)
  1841. del arg40_1
  1842. del arg41_1
  1843. del buf85
  1844. buf90 = buf84; del buf84 # reuse
  1845. # Source Nodes: [out_36, out_37, out_38], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1846. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20[grid(65536, 9)](arg42_1, buf90, 65536, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1847. del arg42_1
  1848. # Source Nodes: [out_36, out_37, out_38], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1849. buf91 = extern_kernels.convolution(buf89, buf90, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1850. assert_size_stride(buf91, (1, 256, 15, 15), (57600, 1, 3840, 256))
  1851. del buf89
  1852. del buf90
  1853. buf92 = reinterpret_tensor(buf87, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf87 # reuse
  1854. buf93 = reinterpret_tensor(buf86, (1, 256, 1, 1), (256, 1, 256, 256), 0); del buf86 # reuse
  1855. # Source Nodes: [out_39], Original ATen: [aten._native_batch_norm_legit_functional]
  1856. triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf91, arg104_1, arg105_1, buf92, buf93, arg104_1, arg105_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
  1857. del arg104_1
  1858. del arg105_1
  1859. buf95 = buf83; del buf83 # reuse
  1860. # Source Nodes: [out_39, out_40, out_41], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1861. triton_poi_fused__native_batch_norm_legit_functional_add_relu_22[grid(57600)](buf95, buf91, buf92, buf93, arg43_1, arg44_1, 57600, XBLOCK=256, num_warps=4, num_stages=1)
  1862. del arg43_1
  1863. del arg44_1
  1864. del buf91
  1865. del buf92
  1866. del buf93
  1867. buf96 = empty_strided_cuda((512, 256, 3, 3), (2304, 1, 768, 256), torch.float32)
  1868. # Source Nodes: [out_42], Original ATen: [aten.convolution]
  1869. triton_poi_fused_convolution_23[grid(131072, 9)](arg45_1, buf96, 131072, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1870. del arg45_1
  1871. # Source Nodes: [out_42], Original ATen: [aten.convolution]
  1872. buf97 = extern_kernels.convolution(buf95, buf96, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1873. assert_size_stride(buf97, (1, 512, 8, 8), (32768, 1, 4096, 512))
  1874. del buf96
  1875. buf98 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 1, 1), torch.float32)
  1876. buf99 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 512, 512), torch.float32)
  1877. # Source Nodes: [out_43], Original ATen: [aten._native_batch_norm_legit_functional]
  1878. triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf97, arg107_1, arg108_1, buf98, buf99, arg107_1, arg108_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
  1879. del arg107_1
  1880. del arg108_1
  1881. buf101 = empty_strided_cuda((1, 512, 8, 8), (32768, 1, 4096, 512), torch.float32)
  1882. # Source Nodes: [out_43, out_44], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1883. triton_poi_fused__native_batch_norm_legit_functional_relu_25[grid(32768)](buf97, buf98, buf99, arg46_1, arg47_1, buf101, 32768, XBLOCK=128, num_warps=4, num_stages=1)
  1884. del arg46_1
  1885. del arg47_1
  1886. del buf97
  1887. buf102 = empty_strided_cuda((512, 512, 3, 3), (4608, 1, 1536, 512), torch.float32)
  1888. # Source Nodes: [out_43, out_44, out_45], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1889. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26[grid(262144, 9)](arg48_1, buf102, 262144, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1890. del arg48_1
  1891. # Source Nodes: [out_43, out_44, out_45], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1892. buf103 = extern_kernels.convolution(buf101, buf102, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1893. assert_size_stride(buf103, (1, 512, 8, 8), (32768, 1, 4096, 512))
  1894. buf104 = reinterpret_tensor(buf99, (1, 512, 1, 1), (512, 1, 1, 1), 0); del buf99 # reuse
  1895. buf105 = reinterpret_tensor(buf98, (1, 512, 1, 1), (512, 1, 512, 512), 0); del buf98 # reuse
  1896. # Source Nodes: [out_46], Original ATen: [aten._native_batch_norm_legit_functional]
  1897. triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf103, arg110_1, arg111_1, buf104, buf105, arg110_1, arg111_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
  1898. del arg110_1
  1899. del arg111_1
  1900. # Source Nodes: [getattr_l__self___layer4___0___downsample_0], Original ATen: [aten.convolution]
  1901. buf107 = extern_kernels.convolution(buf95, arg51_1, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1902. assert_size_stride(buf107, (1, 512, 8, 8), (32768, 1, 4096, 512))
  1903. del arg51_1
  1904. del buf95
  1905. buf108 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 1, 1), torch.float32)
  1906. buf109 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 512, 512), torch.float32)
  1907. # Source Nodes: [identity_2], Original ATen: [aten._native_batch_norm_legit_functional]
  1908. triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf107, arg113_1, arg114_1, buf108, buf109, arg113_1, arg114_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
  1909. del arg113_1
  1910. del arg114_1
  1911. buf111 = buf101; del buf101 # reuse
  1912. buf112 = buf111; del buf111 # reuse
  1913. # Source Nodes: [identity_2, out_46, out_47, out_48], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
  1914. triton_poi_fused__native_batch_norm_legit_functional_add_relu_27[grid(32768)](buf112, buf103, buf104, buf105, arg49_1, arg50_1, buf107, buf108, buf109, arg52_1, arg53_1, 32768, XBLOCK=256, num_warps=4, num_stages=1)
  1915. del arg49_1
  1916. del arg50_1
  1917. del arg52_1
  1918. del arg53_1
  1919. del buf103
  1920. del buf104
  1921. del buf105
  1922. buf113 = buf102; del buf102 # reuse
  1923. # Source Nodes: [out_48, out_49], Original ATen: [aten.convolution, aten.relu]
  1924. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26[grid(262144, 9)](arg54_1, buf113, 262144, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1925. del arg54_1
  1926. # Source Nodes: [out_48, out_49], Original ATen: [aten.convolution, aten.relu]
  1927. buf114 = extern_kernels.convolution(buf112, buf113, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1928. assert_size_stride(buf114, (1, 512, 8, 8), (32768, 1, 4096, 512))
  1929. buf115 = reinterpret_tensor(buf109, (1, 512, 1, 1), (512, 1, 1, 1), 0); del buf109 # reuse
  1930. buf116 = reinterpret_tensor(buf108, (1, 512, 1, 1), (512, 1, 512, 512), 0); del buf108 # reuse
  1931. # Source Nodes: [out_50], Original ATen: [aten._native_batch_norm_legit_functional]
  1932. triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf114, arg116_1, arg117_1, buf115, buf116, arg116_1, arg117_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
  1933. del arg116_1
  1934. del arg117_1
  1935. buf118 = buf107; del buf107 # reuse
  1936. # Source Nodes: [out_50, out_51], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
  1937. triton_poi_fused__native_batch_norm_legit_functional_relu_25[grid(32768)](buf114, buf115, buf116, arg55_1, arg56_1, buf118, 32768, XBLOCK=128, num_warps=4, num_stages=1)
  1938. del arg55_1
  1939. del arg56_1
  1940. del buf114
  1941. del buf115
  1942. buf119 = buf113; del buf113 # reuse
  1943. # Source Nodes: [out_50, out_51, out_52], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1944. triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26[grid(262144, 9)](arg57_1, buf119, 262144, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
  1945. del arg57_1
  1946. # Source Nodes: [out_50, out_51, out_52], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
  1947. buf120 = extern_kernels.convolution(buf118, buf119, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
  1948. assert_size_stride(buf120, (1, 512, 8, 8), (32768, 1, 4096, 512))
  1949. del buf118
  1950. del buf119
  1951. buf124 = buf116; del buf116 # reuse
  1952. buf125 = buf124; del buf124 # reuse
  1953. # Source Nodes: [out_53, out_54, out_55, x_4], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.mean, aten.relu]
  1954. triton_per_fused__native_batch_norm_legit_functional_add_mean_relu_28[grid(512)](buf125, buf120, arg58_1, arg59_1, buf112, arg119_1, arg120_1, arg119_1, arg120_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
  1955. del arg119_1
  1956. del arg120_1
  1957. del arg58_1
  1958. del arg59_1
  1959. del buf112
  1960. del buf120
  1961. buf126 = empty_strided_cuda((1, 1000), (1000, 1), torch.float32)
  1962. # Source Nodes: [x_6], Original ATen: [aten.addmm]
  1963. extern_kernels.addmm(arg61_1, reinterpret_tensor(buf125, (1, 512), (0, 1), 0), reinterpret_tensor(arg60_1, (512, 1000), (1, 512), 0), alpha=1, beta=1, out=buf126)
  1964. del arg60_1
  1965. del arg61_1
  1966. del buf125
  1967. # Source Nodes: [x_1], Original ATen: [aten.add]
  1968. triton_poi_fused_add_29[grid(1)](arg64_1, arg64_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1969. del arg64_1
  1970. # Source Nodes: [out_1], Original ATen: [aten.add]
  1971. triton_poi_fused_add_29[grid(1)](arg67_1, arg67_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1972. del arg67_1
  1973. # Source Nodes: [out_4], Original ATen: [aten.add]
  1974. triton_poi_fused_add_29[grid(1)](arg70_1, arg70_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1975. del arg70_1
  1976. # Source Nodes: [out_8], Original ATen: [aten.add]
  1977. triton_poi_fused_add_29[grid(1)](arg73_1, arg73_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1978. del arg73_1
  1979. # Source Nodes: [out_11], Original ATen: [aten.add]
  1980. triton_poi_fused_add_29[grid(1)](arg76_1, arg76_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1981. del arg76_1
  1982. # Source Nodes: [out_15], Original ATen: [aten.add]
  1983. triton_poi_fused_add_29[grid(1)](arg79_1, arg79_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1984. del arg79_1
  1985. # Source Nodes: [out_18], Original ATen: [aten.add]
  1986. triton_poi_fused_add_29[grid(1)](arg82_1, arg82_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1987. del arg82_1
  1988. # Source Nodes: [identity], Original ATen: [aten.add]
  1989. triton_poi_fused_add_29[grid(1)](arg85_1, arg85_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1990. del arg85_1
  1991. # Source Nodes: [out_22], Original ATen: [aten.add]
  1992. triton_poi_fused_add_29[grid(1)](arg88_1, arg88_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1993. del arg88_1
  1994. # Source Nodes: [out_25], Original ATen: [aten.add]
  1995. triton_poi_fused_add_29[grid(1)](arg91_1, arg91_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1996. del arg91_1
  1997. # Source Nodes: [out_29], Original ATen: [aten.add]
  1998. triton_poi_fused_add_29[grid(1)](arg94_1, arg94_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  1999. del arg94_1
  2000. # Source Nodes: [out_32], Original ATen: [aten.add]
  2001. triton_poi_fused_add_29[grid(1)](arg97_1, arg97_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2002. del arg97_1
  2003. # Source Nodes: [identity_1], Original ATen: [aten.add]
  2004. triton_poi_fused_add_29[grid(1)](arg100_1, arg100_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2005. del arg100_1
  2006. # Source Nodes: [out_36], Original ATen: [aten.add]
  2007. triton_poi_fused_add_29[grid(1)](arg103_1, arg103_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2008. del arg103_1
  2009. # Source Nodes: [out_39], Original ATen: [aten.add]
  2010. triton_poi_fused_add_29[grid(1)](arg106_1, arg106_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2011. del arg106_1
  2012. # Source Nodes: [out_43], Original ATen: [aten.add]
  2013. triton_poi_fused_add_29[grid(1)](arg109_1, arg109_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2014. del arg109_1
  2015. # Source Nodes: [out_46], Original ATen: [aten.add]
  2016. triton_poi_fused_add_29[grid(1)](arg112_1, arg112_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2017. del arg112_1
  2018. # Source Nodes: [identity_2], Original ATen: [aten.add]
  2019. triton_poi_fused_add_29[grid(1)](arg115_1, arg115_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2020. del arg115_1
  2021. # Source Nodes: [out_50], Original ATen: [aten.add]
  2022. triton_poi_fused_add_29[grid(1)](arg118_1, arg118_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2023. del arg118_1
  2024. # Source Nodes: [out_53], Original ATen: [aten.add]
  2025. triton_poi_fused_add_29[grid(1)](arg121_1, arg121_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
  2026. del arg121_1
  2027. return (buf126, )
  2028.  
  2029.  
  2030. def benchmark_compiled_module(times=10, repeat=10):
  2031. from torch._dynamo.testing import rand_strided
  2032. from torch._inductor.utils import print_performance
  2033. arg0_1 = rand_strided((64, 3, 7, 7), (147, 49, 7, 1), device='cuda:0', dtype=torch.float32)
  2034. arg1_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2035. arg2_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2036. arg3_1 = rand_strided((64, 64, 3, 3), (576, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2037. arg4_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2038. arg5_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2039. arg6_1 = rand_strided((64, 64, 3, 3), (576, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2040. arg7_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2041. arg8_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2042. arg9_1 = rand_strided((64, 64, 3, 3), (576, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2043. arg10_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2044. arg11_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2045. arg12_1 = rand_strided((64, 64, 3, 3), (576, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2046. arg13_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2047. arg14_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2048. arg15_1 = rand_strided((128, 64, 3, 3), (576, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2049. arg16_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2050. arg17_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2051. arg18_1 = rand_strided((128, 128, 3, 3), (1152, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2052. arg19_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2053. arg20_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2054. arg21_1 = rand_strided((128, 64, 1, 1), (64, 1, 1, 1), device='cuda:0', dtype=torch.float32)
  2055. arg22_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2056. arg23_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2057. arg24_1 = rand_strided((128, 128, 3, 3), (1152, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2058. arg25_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2059. arg26_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2060. arg27_1 = rand_strided((128, 128, 3, 3), (1152, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2061. arg28_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2062. arg29_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2063. arg30_1 = rand_strided((256, 128, 3, 3), (1152, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2064. arg31_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2065. arg32_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2066. arg33_1 = rand_strided((256, 256, 3, 3), (2304, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2067. arg34_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2068. arg35_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2069. arg36_1 = rand_strided((256, 128, 1, 1), (128, 1, 1, 1), device='cuda:0', dtype=torch.float32)
  2070. arg37_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2071. arg38_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2072. arg39_1 = rand_strided((256, 256, 3, 3), (2304, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2073. arg40_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2074. arg41_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2075. arg42_1 = rand_strided((256, 256, 3, 3), (2304, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2076. arg43_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2077. arg44_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2078. arg45_1 = rand_strided((512, 256, 3, 3), (2304, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2079. arg46_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2080. arg47_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2081. arg48_1 = rand_strided((512, 512, 3, 3), (4608, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2082. arg49_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2083. arg50_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2084. arg51_1 = rand_strided((512, 256, 1, 1), (256, 1, 1, 1), device='cuda:0', dtype=torch.float32)
  2085. arg52_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2086. arg53_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2087. arg54_1 = rand_strided((512, 512, 3, 3), (4608, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2088. arg55_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2089. arg56_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2090. arg57_1 = rand_strided((512, 512, 3, 3), (4608, 9, 3, 1), device='cuda:0', dtype=torch.float32)
  2091. arg58_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2092. arg59_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2093. arg60_1 = rand_strided((1000, 512), (512, 1), device='cuda:0', dtype=torch.float32)
  2094. arg61_1 = rand_strided((1000, ), (1, ), device='cuda:0', dtype=torch.float32)
  2095. arg62_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2096. arg63_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2097. arg64_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2098. arg65_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2099. arg66_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2100. arg67_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2101. arg68_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2102. arg69_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2103. arg70_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2104. arg71_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2105. arg72_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2106. arg73_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2107. arg74_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2108. arg75_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
  2109. arg76_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2110. arg77_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2111. arg78_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2112. arg79_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2113. arg80_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2114. arg81_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2115. arg82_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2116. arg83_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2117. arg84_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2118. arg85_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2119. arg86_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2120. arg87_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2121. arg88_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2122. arg89_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2123. arg90_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
  2124. arg91_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2125. arg92_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2126. arg93_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2127. arg94_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2128. arg95_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2129. arg96_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2130. arg97_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2131. arg98_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2132. arg99_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2133. arg100_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2134. arg101_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2135. arg102_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2136. arg103_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2137. arg104_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2138. arg105_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
  2139. arg106_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2140. arg107_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2141. arg108_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2142. arg109_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2143. arg110_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2144. arg111_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2145. arg112_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2146. arg113_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2147. arg114_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2148. arg115_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2149. arg116_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2150. arg117_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2151. arg118_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2152. arg119_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2153. arg120_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
  2154. arg121_1 = rand_strided((), (), device='cuda:0', dtype=torch.int64)
  2155. arg122_1 = rand_strided((1, 3, 228, 228), (155952, 51984, 228, 1), device='cuda:0', dtype=torch.float32)
  2156. fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1])
  2157. return print_performance(fn, times=times, repeat=repeat)
  2158.  
  2159.  
  2160. if __name__ == "__main__":
  2161. from torch._inductor.wrapper_benchmark import compiled_module_main
  2162. compiled_module_main('None', benchmark_compiled_module)
  2163.  
Advertisement
Add Comment
Please, Sign In to add comment