Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- # AOT ID: ['0_inference']
- from ctypes import c_void_p, c_long
- import torch
- import math
- import random
- import os
- import tempfile
- from math import inf, nan
- from torch._inductor.hooks import run_intermediate_hooks
- from torch._inductor.utils import maybe_profile
- from torch._inductor.codegen.memory_planning import _align as align
- from torch import device, empty_strided
- from torch._inductor.codecache import AsyncCompile
- from torch._inductor.select_algorithm import extern_kernels
- from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short aliases for torch op namespaces and internal runtime helpers.
# NOTE(review): presumably referenced by generated wrapper code outside this view;
# none of the visible Triton kernels use them directly.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Private helpers taken from torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
# Private Inductor ops.
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
- # kernel path: /tmp/torchinductor_chilli/ey/ceyko6dtcfce7u2l2mwcbbrcenhuy63ap6ko6ppldgjdojixko7k.py
- # Source Nodes: [x], Original ATen: [aten.convolution]
- # x => convolution
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_convolution_0(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Relayout feeding aten.convolution: element (y0, x1) of a row-major
    # (3, 51984) buffer is read from in_ptr0[x1 + 51984*y0] and written to
    # out_ptr0[y0 + 3*x1], making the y coordinate the stride-1 one.
    # The tile is (XBLOCK, YBLOCK): x runs along axis 0, y along axis 1.
    ynumel = 3
    xnumel = 51984
    # program_id(2) only contributes to the y offset, i.e. it is an overflow axis
    # for the y grid. Linearize (pid1, pid2) so every y block is visited exactly
    # once. The previous `pid1 * (pid2 + 1)` form mapped e.g. (pid1=2, pid2=0) and
    # (pid1=1, pid2=1) to the same y block and skipped others whenever pid2 > 0.
    # The value is unchanged when pid2 == 0.
    # NOTE(review): assumes the launcher splits the y grid across axes 1 and 2
    # (as Inductor's `grid` helper does) -- confirm against the launcher.
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (x1 + (51984*y0)), xmask & ymask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (3*x1)), tmp0, xmask & ymask)
- import triton
- import triton.language as tl
- from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
- from torch._C import _cuda_getCurrentRawStream as get_raw_stream
- # kernel path: /tmp/torchinductor_chilli/fg/cfgog4hjgpnrqznjhdl7s57zvld6hr66mq3j44s7xmbtaewazh33.py
- # Source Nodes: [x], Original ATen: [aten.convolution]
- # x => convolution
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_convolution_1(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Relayout feeding aten.convolution: for y3 = 3*y1 + y0 (y0 in [0, 3),
    # y1 in [0, 64)) and x2 in [0, 49), copy in_ptr0[x2 + 49*y3] to
    # out_ptr0[y0 + 3*x2 + 147*y1], moving y0 to the stride-1 position.
    # NOTE(review): looks like a (64, 3, 7, 7) weight going channels-last -- confirm.
    ynumel = 192
    xnumel = 49
    # program_id(2) only contributes to the y offset (overflow axis for the y grid).
    # Linearize (pid1, pid2) so each y block is visited exactly once. The old
    # `pid1 * (pid2 + 1)` form duplicated and skipped blocks whenever pid2 > 0.
    # The value is unchanged when pid2 == 0.
    # NOTE(review): assumes the launcher splits the y grid across axes 1 and 2 -- confirm.
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x2 = xindex
    y3 = yindex
    y0 = yindex % 3
    y1 = (yindex // 3)
    tmp0 = tl.load(in_ptr0 + (x2 + (49*y3)), xmask & ymask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (3*x2) + (147*y1)), tmp0, xmask & ymask)
- # kernel path: /tmp/torchinductor_chilli/oy/coyvppj4wtlyylbacoaiwv7r5c2bcteoyqnizkfxwnizxz3hmqxh.py
- # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
- # x_1 => var_mean
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_2(in_ptr0, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    # First stage of a split Welford reduction for batch-norm statistics
    # (x_1 => var_mean). Row x4 = x0 + 64*x1 + 3264*x2 (x0 in [0, 64),
    # x1 in [0, 51), x2 in [0, 2)) reduces the spatial positions
    #     p = r3 + 128*x1 + 6498*x2,   r3 in [0, 128)
    # of column x0. It stores the partial (mean, M2, weight) triple to
    # out_ptr0 / out_ptr1 / out_ptr2 at index x4.
    xnumel = 6528
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x1 = (xindex // 64) % 51
    x0 = xindex % 64
    x2 = (xindex // 3264)
    x4 = xindex
    # Each split covers 6498 positions, but 51 chunks * 128 = 6528 lanes.
    # Lanes with r3 + 128*x1 >= 6498 are therefore padding (tmp2 is False there).
    tmp0 = r3 + (128*x1)
    tmp1 = tl.full([1, 1], 6498, tl.int32)
    tmp2 = tmp0 < tmp1
    # Gather position p as (p // 114 % 114, p % 114) using strides 7296 (= 64*114) and 64.
    # The p % 114 term can omit 6498*x2 because 6498 = 57*114.
    # NOTE(review): the strides suggest a channels-last 114x114 image with 64 channels -- confirm.
    tmp3 = tl.load(in_ptr0 + (x0 + (64*((r3 + (128*x1)) % 114)) + (7296*(((r3 + (128*x1) + (6498*x2)) // 114) % 114))), rmask & tmp2 & xmask, other=0.0)
    # Per-element Welford seeds: (value, M2 = 0.0, weight = 1.0) for real positions,
    # and all zeros for padding, so padded lanes carry no weight.
    tmp4 = tl.full(tmp3.shape, 0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = 0.0
    tmp7 = tl.full(tmp6.shape, 0, tmp6.dtype)
    tmp8 = tl.where(tmp2, tmp6, tmp7)
    tmp9 = 1.0
    tmp10 = tl.full(tmp9.shape, 0, tmp9.dtype)
    tmp11 = tl.where(tmp2, tmp9, tmp10)
    tmp12 = tl.broadcast_to(tmp5, [XBLOCK, RBLOCK])
    tmp13 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
    tmp14 = tl.broadcast_to(tmp11, [XBLOCK, RBLOCK])
    # Also zero lanes outside rmask / xmask before combining.
    tmp16 = tl.where(rmask & xmask, tmp12, 0)
    tmp17 = tl.where(rmask & xmask, tmp13, 0)
    tmp18 = tl.where(rmask & xmask, tmp14, 0)
    # Merge the per-lane (mean, M2, weight) triples along the reduction axis (dim 1).
    tmp19, tmp20, tmp21 = triton_helpers.welford(tmp16, tmp17, tmp18, 1)
    tmp22 = tmp19[:, None]
    tmp23 = tmp20[:, None]
    tmp24 = tmp21[:, None]
    tl.store(out_ptr0 + (x4), tmp22, xmask)
    tl.store(out_ptr1 + (x4), tmp23, xmask)
    tl.store(out_ptr2 + (x4), tmp24, xmask)
- # kernel path: /tmp/torchinductor_chilli/hc/chcl7tgr3scofksd7kbezuz3nvq5wqe3nivxob6lllb27qsootcq.py
- # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
- # x_1 => var_mean
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_3(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Second Welford stage for batch-norm statistics (x_1 => var_mean).
    # Each of the 128 rows (c = row % 64, g = row // 64) merges 51 partial
    # (mean, M2, weight) triples. The triples are read from in_ptr0 / in_ptr1 / in_ptr2
    # at c + 64*k + 3264*g for k in [0, 51). The merged triple is written to
    # out_ptr0 / out_ptr1 / out_ptr2 at the row index.
    xnumel = 128
    rnumel = 51
    RBLOCK: tl.constexpr = 64
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = rows < xnumel
    parts = tl.arange(0, RBLOCK)[None, :]
    part_ok = parts < rnumel
    chan = rows % 64
    grp = rows // 64
    src = chan + (64*parts) + (3264*grp)
    live = part_ok & row_ok
    part_mean = tl.load(in_ptr0 + src, live, other=0.0)
    part_m2 = tl.load(in_ptr1 + src, live, other=0.0)
    part_wt = tl.load(in_ptr2 + src, live, other=0.0)
    # Lanes past rnumel (RBLOCK 64 > 51) or past xnumel become zeros, i.e. weight 0.
    mean_in = tl.where(live, tl.broadcast_to(part_mean, [XBLOCK, RBLOCK]), 0)
    m2_in = tl.where(live, tl.broadcast_to(part_m2, [XBLOCK, RBLOCK]), 0)
    wt_in = tl.where(live, tl.broadcast_to(part_wt, [XBLOCK, RBLOCK]), 0)
    merged_mean, merged_m2, merged_wt = triton_helpers.welford(mean_in, m2_in, wt_in, 1)
    tl.store(out_ptr0 + (rows), merged_mean[:, None], row_ok)
    tl.store(out_ptr1 + (rows), merged_m2[:, None], row_ok)
    tl.store(out_ptr2 + (rows), merged_wt[:, None], row_ok)
- # kernel path: /tmp/torchinductor_chilli/7u/c7uiyrzseejytdgzkmlnbflxex6dricmsk7lfiz42tf55su5rnn5.py
- # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
- # x_1 => add_2, add_3, mul_1, mul_2, mul_3, mul_4, mul_5, var_mean
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_4(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Final Welford stage for batch-norm statistics over 12996 elements per row.
    # Each of the 64 rows merges 2 partial (mean, M2, weight) triples read at
    # x0 + 64*r1 and writes:
    #   out_ptr0 = mean, out_ptr1 = M2 (sum of squared deviations),
    #   out_ptr3 = 0.1*mean + 0.9*in_ptr3[x0],
    #   out_ptr5 = 0.1*(M2/12996 * 12996/12995) + 0.9*in_ptr4[x0].
    # NOTE(review): in_ptr3/in_ptr4 look like running_mean/running_var updated with
    # momentum 0.1 -- confirm at the call site.
    xnumel = 64
    rnumel = 2
    RBLOCK: tl.constexpr = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (64*r1)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (x0 + (64*r1)), rmask & xmask, other=0.0)
    tmp2 = tl.load(in_ptr2 + (x0 + (64*r1)), rmask & xmask, other=0.0)
    tmp18 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    tmp27 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp5 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp3, 0)
    tmp8 = tl.where(rmask & xmask, tmp4, 0)
    tmp9 = tl.where(rmask & xmask, tmp5, 0)
    # Merge the (mean, M2, weight) partials along dim 1. The merged weight (tmp15) is unused.
    tmp10, tmp11, tmp12 = triton_helpers.welford(tmp7, tmp8, tmp9, 1)
    tmp13 = tmp10[:, None]
    tmp14 = tmp11[:, None]
    tmp15 = tmp12[:, None]
    tmp16 = 0.1
    tmp17 = tmp13 * tmp16
    tmp19 = 0.9
    tmp20 = tmp18 * tmp19
    tmp21 = tmp17 + tmp20
    # M2 / n gives the biased variance. 1.0000769526741053 = 12996/12995 rescales it
    # to the unbiased (n - 1) estimate.
    tmp22 = 12996.0
    tmp23 = tmp14 / tmp22
    tmp24 = 1.0000769526741053
    tmp25 = tmp23 * tmp24
    tmp26 = tmp25 * tmp16
    tmp28 = tmp27 * tmp19
    tmp29 = tmp26 + tmp28
    tl.store(out_ptr3 + (x0), tmp21, xmask)
    tl.store(out_ptr5 + (x0), tmp29, xmask)
    tl.store(out_ptr0 + (x0), tmp13, xmask)
    tl.store(out_ptr1 + (x0), tmp14, xmask)
- # kernel path: /tmp/torchinductor_chilli/qw/cqwedz2pcqdgqlxggudqdyaimhwkci7xpj64jg25hdihc6bvvrat.py
- # Source Nodes: [x_1, x_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- # x_1 => add_1, add_4, mul, mul_6, rsqrt, sub, var_mean
- # x_2 => relu
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_5(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, XBLOCK : tl.constexpr):
    # In-place fused batch-norm normalization + ReLU over 831744 elements
    # (x_1, x_2). For flat index i with c = i % 64:
    #     v = in_out_ptr0[i]
    #     in_out_ptr0[i] = max(0, (v - in_ptr0[c]) * rsqrt(in_ptr1[c] / 12996 + 1e-5)
    #                             * in_ptr2[c] + in_ptr3[c])
    # NOTE(review): in_ptr0/in_ptr1 look like per-channel mean and sum of squared
    # deviations, and in_ptr2/in_ptr3 the affine weight/bias -- confirm at the call site.
    xnumel = 831744
    lanes = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = lanes < xnumel
    chan = lanes % 64
    val = tl.load(in_out_ptr0 + (lanes), live)
    # Per-channel vectors are reused by every program; hint the cache to keep them.
    mean = tl.load(in_ptr0 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    gamma = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    beta = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    count = 12996.0
    eps = 1e-05
    inv_std = libdevice.rsqrt(m2 / count + eps)
    normed = (val - mean) * inv_std * gamma + beta
    tl.store(in_out_ptr0 + (lanes), triton_helpers.maximum(0, normed), live)
- # kernel path: /tmp/torchinductor_chilli/kc/ckc6pbzy63vjkiqb3hj3kfrrcfnpzyr4o6bd26275bq73l3asf3q.py
- # Source Nodes: [x_1, x_2, x_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.max_pool2d_with_indices, aten.relu]
- # x_1 => add_1, add_4, mul, mul_6, rsqrt, sub, var_mean
- # x_2 => relu
- # x_3 => max_pool2d_with_indices
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_max_pool2d_with_indices_relu_6(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # 3x3 max-pool with stride 2 and padding 1 (x_3 => max_pool2d_with_indices).
    # Only the max values are stored here.
    # Output flat index x4 = x0 + 64*x1 + 3648*x2, where x0 in [0, 64) is the
    # innermost coordinate (presumably channel), x1 in [0, 57) is the output column
    # and x2 in [0, 57) is the output row.
    # The taps are input row 2*x2 + dr and column 2*x1 + dc for dr, dc in {-1, 0, 1},
    # addressed with strides 7296 (row) and 64 (column). The base offset
    # x0 + 128*x1 + 14592*x2 is input (2*x2, 2*x1), so e.g. -7360 = -7296 - 64
    # is the (-1, -1) tap.
    # A tap whose row or column lies outside [0, 114) contributes -inf.
    # NOTE(review): the strides suggest a channels-last 114x114x64 input -- confirm.
    xnumel = 207936
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 3648)
    x1 = (xindex // 64) % 57
    x0 = xindex % 64
    x4 = xindex
    # Row 2*x2 - 1 in bounds?
    tmp0 = (-1) + (2*x2)
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 114, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    # Column 2*x1 - 1 in bounds?
    tmp6 = (-1) + (2*x1)
    tmp7 = tmp6 >= tmp1
    tmp8 = tmp6 < tmp3
    tmp9 = tmp7 & tmp8
    # Tap (-1, -1).
    tmp10 = tmp5 & tmp9
    tmp11 = tl.load(in_ptr0 + ((-7360) + x0 + (128*x1) + (14592*x2)), tmp10 & xmask, other=0.0)
    tmp12 = tl.full(tmp11.shape, float("-inf"), tmp11.dtype)
    tmp13 = tl.where(tmp10, tmp11, tmp12)
    # Column 2*x1 in bounds? Always true here since x1 <= 56, but generated generically.
    tmp14 = 2*x1
    tmp15 = tmp14 >= tmp1
    tmp16 = tmp14 < tmp3
    tmp17 = tmp15 & tmp16
    # Tap (-1, 0).
    tmp18 = tmp5 & tmp17
    tmp19 = tl.load(in_ptr0 + ((-7296) + x0 + (128*x1) + (14592*x2)), tmp18 & xmask, other=0.0)
    tmp20 = tl.full(tmp19.shape, float("-inf"), tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = triton_helpers.maximum(tmp21, tmp13)
    # Column 2*x1 + 1 in bounds?
    tmp23 = 1 + (2*x1)
    tmp24 = tmp23 >= tmp1
    tmp25 = tmp23 < tmp3
    tmp26 = tmp24 & tmp25
    # Tap (-1, +1).
    tmp27 = tmp5 & tmp26
    tmp28 = tl.load(in_ptr0 + ((-7232) + x0 + (128*x1) + (14592*x2)), tmp27 & xmask, other=0.0)
    tmp29 = tl.full(tmp28.shape, float("-inf"), tmp28.dtype)
    tmp30 = tl.where(tmp27, tmp28, tmp29)
    tmp31 = triton_helpers.maximum(tmp30, tmp22)
    # Row 2*x2 in bounds?
    tmp32 = 2*x2
    tmp33 = tmp32 >= tmp1
    tmp34 = tmp32 < tmp3
    tmp35 = tmp33 & tmp34
    # Tap (0, -1).
    tmp36 = tmp35 & tmp9
    tmp37 = tl.load(in_ptr0 + ((-64) + x0 + (128*x1) + (14592*x2)), tmp36 & xmask, other=0.0)
    tmp38 = tl.full(tmp37.shape, float("-inf"), tmp37.dtype)
    tmp39 = tl.where(tmp36, tmp37, tmp38)
    tmp40 = triton_helpers.maximum(tmp39, tmp31)
    # Tap (0, 0).
    tmp41 = tmp35 & tmp17
    tmp42 = tl.load(in_ptr0 + (x0 + (128*x1) + (14592*x2)), tmp41 & xmask, other=0.0)
    tmp43 = tl.full(tmp42.shape, float("-inf"), tmp42.dtype)
    tmp44 = tl.where(tmp41, tmp42, tmp43)
    tmp45 = triton_helpers.maximum(tmp44, tmp40)
    # Tap (0, +1).
    tmp46 = tmp35 & tmp26
    tmp47 = tl.load(in_ptr0 + (64 + x0 + (128*x1) + (14592*x2)), tmp46 & xmask, other=0.0)
    tmp48 = tl.full(tmp47.shape, float("-inf"), tmp47.dtype)
    tmp49 = tl.where(tmp46, tmp47, tmp48)
    tmp50 = triton_helpers.maximum(tmp49, tmp45)
    # Row 2*x2 + 1 in bounds?
    tmp51 = 1 + (2*x2)
    tmp52 = tmp51 >= tmp1
    tmp53 = tmp51 < tmp3
    tmp54 = tmp52 & tmp53
    # Tap (+1, -1).
    tmp55 = tmp54 & tmp9
    tmp56 = tl.load(in_ptr0 + (7232 + x0 + (128*x1) + (14592*x2)), tmp55 & xmask, other=0.0)
    tmp57 = tl.full(tmp56.shape, float("-inf"), tmp56.dtype)
    tmp58 = tl.where(tmp55, tmp56, tmp57)
    tmp59 = triton_helpers.maximum(tmp58, tmp50)
    # Tap (+1, 0).
    tmp60 = tmp54 & tmp17
    tmp61 = tl.load(in_ptr0 + (7296 + x0 + (128*x1) + (14592*x2)), tmp60 & xmask, other=0.0)
    tmp62 = tl.full(tmp61.shape, float("-inf"), tmp61.dtype)
    tmp63 = tl.where(tmp60, tmp61, tmp62)
    tmp64 = triton_helpers.maximum(tmp63, tmp59)
    # Tap (+1, +1).
    tmp65 = tmp54 & tmp26
    tmp66 = tl.load(in_ptr0 + (7360 + x0 + (128*x1) + (14592*x2)), tmp65 & xmask, other=0.0)
    tmp67 = tl.full(tmp66.shape, float("-inf"), tmp66.dtype)
    tmp68 = tl.where(tmp65, tmp66, tmp67)
    tmp69 = triton_helpers.maximum(tmp68, tmp64)
    tl.store(out_ptr0 + (x4), tmp69, xmask)
- # kernel path: /tmp/torchinductor_chilli/vk/cvkspjpsscasoaaqpr4bknw6rubi6rs377kkeeghxsldfkeoncma.py
- # Source Nodes: [out], Original ATen: [aten.convolution]
- # out => convolution_1
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_convolution_7(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Relayout feeding aten.convolution (out => convolution_1): for
    # y3 = 64*y1 + y0 (y0, y1 in [0, 64)) and x2 in [0, 9), copy
    # in_ptr0[x2 + 9*y3] to out_ptr0[y0 + 64*x2 + 576*y1], moving y0 to stride 1.
    ynumel = 4096
    xnumel = 9
    # program_id(2) only contributes to the y offset (overflow axis for the y grid).
    # Linearize (pid1, pid2) so each y block is visited exactly once. The old
    # `pid1 * (pid2 + 1)` form duplicated and skipped blocks whenever pid2 > 0.
    # The value is unchanged when pid2 == 0.
    # NOTE(review): assumes the launcher splits the y grid across axes 1 and 2 -- confirm.
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x2 = xindex
    y3 = yindex
    y0 = yindex % 64
    y1 = (yindex // 64)
    # Mask on y as well. ymask was previously computed but unused, which is only safe
    # while the y blocks tile ynumel exactly. A split y grid rounded up by the
    # launcher can produce yindex >= ynumel.
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask & ymask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (64*x2) + (576*y1)), tmp0, xmask & ymask)
- # kernel path: /tmp/torchinductor_chilli/od/codvnm4apnzie5266x5p4vcefr4uplcosymmy6te7w4m4disvhsf.py
- # Source Nodes: [out_1], Original ATen: [aten._native_batch_norm_legit_functional]
- # out_1 => add_7, add_8, mul_10, mul_11, mul_12, mul_8, mul_9, var_mean_1
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_red_fused__native_batch_norm_legit_functional_8(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # Batch-norm statistics for out_1 (var_mean_1 plus running-stat updates).
    # Each of the 64 rows reduces the 3249 values in_ptr0[x0 + 64*r1] with a looped Welford.
    # Writes:
    #   out_ptr0 = mean, out_ptr1 = M2 (sum of squared deviations),
    #   out_ptr3 = 0.1*mean + 0.9*in_ptr1[x0],
    #   out_ptr5 = 0.1*(M2/3249 * 3249/3248) + 0.9*in_ptr2[x0].
    # NOTE(review): in_ptr1/in_ptr2 look like running_mean/running_var with momentum 0.1 -- confirm.
    xnumel = 64
    rnumel = 3249
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Per-lane Welford state (mean, M2, weight), accumulated across RBLOCK-wide tiles.
    tmp2_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp2_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp2_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (64*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        # `roffset == 0` presumably tells welford_reduce this is the first tile -- confirm.
        tmp2_mean_next, tmp2_m2_next, tmp2_weight_next = triton_helpers.welford_reduce(
            tmp1, tmp2_mean, tmp2_m2, tmp2_weight, roffset == 0
        )
        # Lanes outside rmask & xmask keep their previous state.
        tmp2_mean = tl.where(rmask & xmask, tmp2_mean_next, tmp2_mean)
        tmp2_m2 = tl.where(rmask & xmask, tmp2_m2_next, tmp2_m2)
        tmp2_weight = tl.where(rmask & xmask, tmp2_weight_next, tmp2_weight)
    # Merge the RBLOCK per-lane states of each row along dim 1.
    tmp2_tmp, tmp3_tmp, tmp4_tmp = triton_helpers.welford(
        tmp2_mean, tmp2_m2, tmp2_weight, 1
    )
    tmp2 = tmp2_tmp[:, None]
    tmp3 = tmp3_tmp[:, None]
    tmp4 = tmp4_tmp[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
    tl.store(out_ptr1 + (x0), tmp3, xmask)
    tmp7 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
    tmp16 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    tmp5 = 0.1
    tmp6 = tmp2 * tmp5
    tmp8 = 0.9
    tmp9 = tmp7 * tmp8
    tmp10 = tmp6 + tmp9
    # M2 / n gives the biased variance. 1.000307881773399 = 3249/3248 rescales it
    # to the unbiased (n - 1) estimate.
    tmp11 = 3249.0
    tmp12 = tmp3 / tmp11
    tmp13 = 1.000307881773399
    tmp14 = tmp12 * tmp13
    tmp15 = tmp14 * tmp5
    tmp17 = tmp16 * tmp8
    tmp18 = tmp15 + tmp17
    tl.store(out_ptr3 + (x0), tmp10, xmask)
    tl.store(out_ptr5 + (x0), tmp18, xmask)
- # kernel path: /tmp/torchinductor_chilli/lo/cloanpjqs4btyhgyzyxhsxwzurvxpo2wa465u65kfdpu6mw4chyr.py
- # Source Nodes: [out_1, out_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- # out_1 => add_6, add_9, mul_13, mul_7, rsqrt_1, sub_1, var_mean_1
- # out_2 => relu_1
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_9(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Fused batch-norm normalization + ReLU (out_1, out_2) over 207936 elements.
    # For flat index i with c = i % 64:
    #     out_ptr0[i] = max(0, (in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 3249 + 1e-5)
    #                          * in_ptr3[c] + in_ptr4[c])
    # NOTE(review): in_ptr1/in_ptr2 look like per-channel mean and sum of squared
    # deviations, and in_ptr3/in_ptr4 the affine weight/bias -- confirm at the call site.
    xnumel = 207936
    lanes = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = lanes < xnumel
    chan = lanes % 64
    val = tl.load(in_ptr0 + (lanes), live)
    # Per-channel vectors are reused by every program; hint the cache to keep them.
    mean = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    gamma = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    beta = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    count = 3249.0
    eps = 1e-05
    inv_std = libdevice.rsqrt(m2 / count + eps)
    normed = (val - mean) * inv_std * gamma + beta
    tl.store(out_ptr0 + (lanes), triton_helpers.maximum(0, normed), live)
- # kernel path: /tmp/torchinductor_chilli/vt/cvt4hpbqmxmk4aanytwt546bgwqi4hivd5oal6ks3lpmczefgpzc.py
- # Source Nodes: [out_4, out_5, out_6], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- # out_4 => add_11, add_14, mul_14, mul_20, rsqrt_2, sub_2, var_mean_2
- # out_5 => add_15
- # out_6 => relu_2
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_add_relu_10(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, XBLOCK : tl.constexpr):
    # Batch-norm normalization, residual add and ReLU (out_4, out_5, out_6) over
    # 207936 elements. For flat index i with c = i % 64:
    #     bn = (in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 3249 + 1e-5) * in_ptr3[c] + in_ptr4[c]
    #     in_out_ptr0[i] = max(0, bn + in_out_ptr0[i])
    xnumel = 207936
    lanes = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = lanes < xnumel
    chan = lanes % 64
    val = tl.load(in_ptr0 + (lanes), live)
    # Per-channel vectors are reused by every program; hint the cache to keep them.
    mean = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    gamma = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    beta = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    residual = tl.load(in_out_ptr0 + (lanes), live)
    count = 3249.0
    eps = 1e-05
    inv_std = libdevice.rsqrt(m2 / count + eps)
    normed = (val - mean) * inv_std * gamma + beta
    tl.store(in_out_ptr0 + (lanes), triton_helpers.maximum(0, normed + residual), live)
- # kernel path: /tmp/torchinductor_chilli/2x/c2xz42gsxbn7afzhg3pitrfsomlnmlcqformp7rf27frh7ow6fjy.py
- # Source Nodes: [out_14], Original ATen: [aten.convolution]
- # out_14 => convolution_5
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_convolution_11(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Relayout feeding aten.convolution (out_14 => convolution_5): for
    # y3 = 64*y1 + y0 (y0 in [0, 64), y1 in [0, 128)) and x2 in [0, 9), copy
    # in_ptr0[x2 + 9*y3] to out_ptr0[y0 + 64*x2 + 576*y1], moving y0 to stride 1.
    ynumel = 8192
    xnumel = 9
    # program_id(2) only contributes to the y offset (overflow axis for the y grid).
    # Linearize (pid1, pid2) so each y block is visited exactly once. The old
    # `pid1 * (pid2 + 1)` form duplicated and skipped blocks whenever pid2 > 0.
    # The value is unchanged when pid2 == 0.
    # NOTE(review): assumes the launcher splits the y grid across axes 1 and 2 -- confirm.
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x2 = xindex
    y3 = yindex
    y0 = yindex % 64
    y1 = (yindex // 64)
    # Mask on y as well. ymask was previously computed but unused, which is only safe
    # while the y blocks tile ynumel exactly. A split y grid rounded up by the
    # launcher can produce yindex >= ynumel.
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask & ymask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (64*x2) + (576*y1)), tmp0, xmask & ymask)
- # kernel path: /tmp/torchinductor_chilli/yx/cyxk35gzq34tsaipo6wzfghuopwljvfsgir35jcovgkosngo2ity.py
- # Source Nodes: [out_15], Original ATen: [aten._native_batch_norm_legit_functional]
- # out_15 => add_29, add_30, mul_36, mul_37, mul_38, mul_39, mul_40, var_mean_5
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_12(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel):
    # Batch-norm statistics for out_15 (var_mean_5 plus running-stat updates).
    # There is one row per program (XBLOCK = 1), and the whole 841-element row
    # in_ptr0[x0 + 128*r1] fits in a single RBLOCK = 1024 tile.
    # Two-pass variance: mean first, then the sum of squared deviations.
    # Writes:
    #   out_ptr0 = mean, out_ptr1 = sum of squared deviations,
    #   out_ptr3 = 0.1*mean + 0.9*in_ptr1[x0],
    #   out_ptr5 = 0.1*(sumsq/841 * 841/840) + 0.9*in_ptr2[x0].
    # NOTE(review): in_ptr1/in_ptr2 look like running_mean/running_var with momentum 0.1 -- confirm.
    xnumel = 128
    XBLOCK: tl.constexpr = 1
    rnumel = 841
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*r1)), rmask & xmask, other=0.0)
    tmp19 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
    tmp28 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    tmp1 = tl.broadcast_to(tmp0, [RBLOCK])
    # tmp3 is not used below.
    tmp3 = tl.where(rmask & xmask, tmp1, 0)
    # Pass 1: mean = sum / 841 (lanes past rnumel are zeroed).
    tmp4 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp6 = tl.where(rmask & xmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp8 = tl.full([1], 841, tl.int32)
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp7 / tmp9
    # Pass 2: sum of squared deviations. Padded lanes are masked out again, since
    # (0 - mean)^2 would be non-zero.
    tmp11 = tmp1 - tmp10
    tmp12 = tmp11 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp15, 0))
    tmp17 = 0.1
    tmp18 = tmp10 * tmp17
    tmp20 = 0.9
    tmp21 = tmp19 * tmp20
    tmp22 = tmp18 + tmp21
    # sumsq / n gives the biased variance. 1.0011904761904762 = 841/840 rescales it
    # to the unbiased (n - 1) estimate.
    tmp23 = 841.0
    tmp24 = tmp16 / tmp23
    tmp25 = 1.0011904761904762
    tmp26 = tmp24 * tmp25
    tmp27 = tmp26 * tmp17
    tmp29 = tmp28 * tmp20
    tmp30 = tmp27 + tmp29
    tl.store(out_ptr3 + (x0), tmp22, xmask)
    tl.store(out_ptr5 + (x0), tmp30, xmask)
    tl.store(out_ptr0 + (x0), tmp10, xmask)
    tl.store(out_ptr1 + (x0), tmp16, xmask)
- # kernel path: /tmp/torchinductor_chilli/vv/cvvk47b5xgybm23ez23a7raivkf2sv2kr4orbokbjadnck4nft3z.py
- # Source Nodes: [out_15, out_16], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- # out_15 => add_28, add_31, mul_35, mul_41, rsqrt_5, sub_5, var_mean_5
- # out_16 => relu_5
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Fused batch-norm normalization + ReLU (out_15, out_16) over 107648 elements.
    # For flat index i with c = i % 128:
    #     out_ptr0[i] = max(0, (in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 841 + 1e-5)
    #                          * in_ptr3[c] + in_ptr4[c])
    # NOTE(review): in_ptr1/in_ptr2 look like per-channel mean and sum of squared
    # deviations, and in_ptr3/in_ptr4 the affine weight/bias -- confirm at the call site.
    xnumel = 107648
    lanes = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = lanes < xnumel
    chan = lanes % 128
    val = tl.load(in_ptr0 + (lanes), live)
    # Per-channel vectors are reused by every program; hint the cache to keep them.
    mean = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    gamma = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    beta = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    count = 841.0
    eps = 1e-05
    inv_std = libdevice.rsqrt(m2 / count + eps)
    normed = (val - mean) * inv_std * gamma + beta
    tl.store(out_ptr0 + (lanes), triton_helpers.maximum(0, normed), live)
- # kernel path: /tmp/torchinductor_chilli/a2/ca2vg66dc6tbazekmxs7fptl72wa3nskx3lnbicxijilno3uoule.py
- # Source Nodes: [out_15, out_16, out_17], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- # out_15 => add_28, add_31, mul_35, mul_41, rsqrt_5, sub_5, var_mean_5
- # out_16 => relu_5
- # out_17 => convolution_6
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Relayout feeding aten.convolution (out_17 => convolution_6): for
    # y3 = 128*y1 + y0 (y0, y1 in [0, 128)) and x2 in [0, 9), copy
    # in_ptr0[x2 + 9*y3] to out_ptr0[y0 + 128*x2 + 1152*y1], moving y0 to stride 1.
    ynumel = 16384
    xnumel = 9
    # program_id(2) only contributes to the y offset (overflow axis for the y grid).
    # Linearize (pid1, pid2) so each y block is visited exactly once. The old
    # `pid1 * (pid2 + 1)` form duplicated and skipped blocks whenever pid2 > 0.
    # The value is unchanged when pid2 == 0.
    # NOTE(review): assumes the launcher splits the y grid across axes 1 and 2 -- confirm.
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x2 = xindex
    y3 = yindex
    y0 = yindex % 128
    y1 = (yindex // 128)
    # Mask on y as well. ymask was previously computed but unused, which is only safe
    # while the y blocks tile ynumel exactly. A split y grid rounded up by the
    # launcher can produce yindex >= ynumel.
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), xmask & ymask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (128*x2) + (1152*y1)), tmp0, xmask & ymask)
- # kernel path: /tmp/torchinductor_chilli/6t/c6t2wfxi5lmlky6np3qaojn7zwyvd6xsx545pckbcrchazw4yg5z.py
- # Source Nodes: [identity, out_18, out_19, out_20], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- # identity => add_38, add_41, mul_49, mul_55, rsqrt_7, sub_7, var_mean_7
- # out_18 => add_33, add_36, mul_42, mul_48, rsqrt_6, sub_6, var_mean_6
- # out_19 => add_42
- # out_20 => relu_6
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_add_relu_15(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, xnumel, XBLOCK : tl.constexpr):
    # Sum of two batch-norm-normalized branches followed by ReLU
    # (out_18 + identity => out_19, out_20) over 107648 elements.
    # For flat index i with c = i % 128:
    #     a = (in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 841 + 1e-5) * in_ptr3[c] + in_ptr4[c]
    #     b = (in_ptr5[i] - in_ptr6[c]) * rsqrt(in_ptr7[c] / 841 + 1e-5) * in_ptr8[c] + in_ptr9[c]
    #     in_out_ptr0[i] = max(0, a + b)
    # in_out_ptr0 is only written here, never read.
    xnumel = 107648
    lanes = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = lanes < xnumel
    chan = lanes % 128
    # Main branch operands.
    main_val = tl.load(in_ptr0 + (lanes), live)
    main_mean = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    main_m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    main_gamma = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    main_beta = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    # Shortcut branch operands.
    skip_val = tl.load(in_ptr5 + (lanes), live)
    skip_mean = tl.load(in_ptr6 + (chan), live, eviction_policy='evict_last')
    skip_m2 = tl.load(in_ptr7 + (chan), live, eviction_policy='evict_last')
    skip_gamma = tl.load(in_ptr8 + (chan), live, eviction_policy='evict_last')
    skip_beta = tl.load(in_ptr9 + (chan), live, eviction_policy='evict_last')
    count = 841.0
    eps = 1e-05
    main_out = (main_val - main_mean) * libdevice.rsqrt(main_m2 / count + eps) * main_gamma + main_beta
    skip_out = (skip_val - skip_mean) * libdevice.rsqrt(skip_m2 / count + eps) * skip_gamma + skip_beta
    tl.store(in_out_ptr0 + (lanes), triton_helpers.maximum(0, main_out + skip_out), live)
- # kernel path: /tmp/torchinductor_chilli/26/c26rt5bj4lhziddwgxx2yplkvhjuuc7bht32mqdzdeewrv6zwvso.py
- # Source Nodes: [out_25, out_26, out_27], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- # out_25 => add_49, add_52, mul_63, mul_69, rsqrt_9, sub_9, var_mean_9
- # out_26 => add_53
- # out_27 => relu_8
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_add_relu_16(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, XBLOCK : tl.constexpr):
    # Pointwise epilogue for out_25 / out_26 / out_27:
    #   in_out_ptr0[i] = relu(((in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 841 + 1e-5))
    #                         * in_ptr3[c] + in_ptr4[c] + in_out_ptr0[i])
    # where c = i % 128 selects the per-channel operand.
    # NOTE(review): in_ptr1 / in_ptr2 look like the per-channel mean and sum of
    # squared deviations from an earlier reduction kernel -- confirm at the call site.
    xnumel = 107648
    base = tl.program_id(0) * XBLOCK
    flat = base + tl.arange(0, XBLOCK)[:]
    live = flat < xnumel
    chan = flat % 128
    src = tl.load(in_ptr0 + (flat), live)
    mu = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    scale = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    shift = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    # The addend lives in the output buffer; read it before overwriting it below.
    addend = tl.load(in_out_ptr0 + (flat), live)
    n_elems = 841.0
    eps = 1e-05
    inv_std = libdevice.rsqrt((m2 / n_elems) + eps)
    affine = ((src - mu) * inv_std) * scale + shift
    result = triton_helpers.maximum(0, affine + addend)
    tl.store(in_out_ptr0 + (flat), result, live)
- # kernel path: /tmp/torchinductor_chilli/z2/cz2hmaiqkwsixgjuxaqmaywuxpi2lcassripzkjabnb4j44mcjed.py
- # Source Nodes: [out_28], Original ATen: [aten.convolution]
- # out_28 => convolution_10
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_convolution_17(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Layout-conversion copy for the out_28 convolution operand.
    # Reads in_ptr0 as a contiguous (ynumel, 9) array (element (y, x) at 9*y + x).
    # Writes it to out_ptr0 at y0 + 128*x + 1152*y1, with y0 = y % 128 and
    # y1 = y // 128, so the 9-element axis moves outside the 128-element axis.
    # NOTE(review): presumably a (256, 128, 3, 3) weight repacked to channels-last
    # -- confirm at the call site.
    ynumel = 32768
    xnumel = 9
    # Flatten the (axis 1, axis 2) program ids into a single y-tile index.
    # The former `tl.program_id(1) * (tl.program_id(2) + 1)` is only right when
    # axis 2 has one program. With more, every program with program_id(1) == 0
    # lands on tile 0 and many tiles are never visited.
    y_tile = tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)
    yoffset = y_tile * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Guard both axes in case the launched y programs cover more than ynumel rows.
    # With a single axis-2 program this is identical to the old xmask-only guard.
    mask = xmask & ymask
    x2 = xindex
    y3 = yindex
    y0 = yindex % 128
    y1 = yindex // 128
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), mask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (128*x2) + (1152*y1)), tmp0, mask)
- # kernel path: /tmp/torchinductor_chilli/5m/c5muflxhnwscn5dsarqbe5g35q5b3e7tjk5aamn42344ztmgffpc.py
- # Source Nodes: [out_29], Original ATen: [aten._native_batch_norm_legit_functional]
- # out_29 => add_56, add_57, mul_71, mul_72, mul_73, mul_74, mul_75, var_mean_10
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_18(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Per-row statistics for out_29, computed over 225 elements per row.
    # Row x's elements are read at in_ptr0[x + 256*r]. The kernel stores:
    #   out_ptr0[x] = mean
    #   out_ptr1[x] = sum of squared deviations from the mean
    #   out_ptr3[x] = 0.1 * mean + 0.9 * in_ptr1[x]
    #   out_ptr5[x] = 0.1 * (M2 / 225) * (225 / 224) + 0.9 * in_ptr2[x]
    # NOTE(review): out_ptr3 / out_ptr5 look like updated running mean / var --
    # confirm at the call site.
    xnumel = 256
    rnumel = 225
    RBLOCK: tl.constexpr = 256
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = rows < xnumel
    cols = tl.arange(0, RBLOCK)[None, :]
    col_ok = cols < rnumel
    both_ok = col_ok & row_ok
    vals = tl.load(in_ptr0 + (rows + (256*cols)), both_ok, other=0.0)
    prev_mean = tl.load(in_ptr1 + (rows), row_ok, eviction_policy='evict_last')
    prev_var = tl.load(in_ptr2 + (rows), row_ok, eviction_policy='evict_last')
    tile = tl.broadcast_to(vals, [XBLOCK, RBLOCK])
    # Pass 1: mean over the valid columns.
    total = tl.sum(tl.where(both_ok, tile, 0), 1)[:, None]
    count = tl.full([XBLOCK, 1], 225, tl.int32).to(tl.float32)
    mean = total / count
    # Pass 2: sum of squared deviations. Padded columns are zeroed by the where().
    dev = tile - mean
    sq = tl.broadcast_to(dev * dev, [XBLOCK, RBLOCK])
    m2 = tl.sum(tl.where(both_ok, sq, 0), 1)[:, None]
    momentum = 0.1
    keep = 0.9
    new_mean = mean * momentum + prev_mean * keep
    n_elems = 225.0
    bessel = 1.0044642857142858  # 225 / 224
    new_var = ((m2 / n_elems) * bessel) * momentum + prev_var * keep
    tl.store(out_ptr3 + (rows), new_mean, row_ok)
    tl.store(out_ptr5 + (rows), new_var, row_ok)
    tl.store(out_ptr0 + (rows), mean, row_ok)
    tl.store(out_ptr1 + (rows), m2, row_ok)
- # kernel path: /tmp/torchinductor_chilli/o6/co6ycddj5sgoncwk3rdnpd65b5kx2eufjx4obfxbpg7ev3hii5ti.py
- # Source Nodes: [out_29, out_30], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- # out_29 => add_55, add_58, mul_70, mul_76, rsqrt_10, sub_10, var_mean_10
- # out_30 => relu_9
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_19(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Pointwise epilogue for out_29 / out_30:
    #   out_ptr0[i] = relu(((in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 225 + 1e-5))
    #                      * in_ptr3[c] + in_ptr4[c])
    # where c = i % 256 selects the per-channel operand.
    xnumel = 57600
    flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = flat < xnumel
    chan = flat % 256
    src = tl.load(in_ptr0 + (flat), live)
    mu = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    scale = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    shift = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    n_elems = 225.0
    eps = 1e-05
    inv_std = libdevice.rsqrt((m2 / n_elems) + eps)
    affine = ((src - mu) * inv_std) * scale + shift
    tl.store(out_ptr0 + (flat), triton_helpers.maximum(0, affine), live)
- # kernel path: /tmp/torchinductor_chilli/nf/cnfgqtfbw5smkkagwken4rsi6rzry342age3cxo5oozd2ewyyhez.py
- # Source Nodes: [out_29, out_30, out_31], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- # out_29 => add_55, add_58, mul_70, mul_76, rsqrt_10, sub_10, var_mean_10
- # out_30 => relu_9
- # out_31 => convolution_11
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Layout-conversion copy for the out_31 convolution operand.
    # Reads in_ptr0 as a contiguous (ynumel, 9) array (element (y, x) at 9*y + x).
    # Writes it to out_ptr0 at y0 + 256*x + 2304*y1, with y0 = y % 256 and
    # y1 = y // 256, so the 9-element axis moves outside the 256-element axis.
    # NOTE(review): presumably a (256, 256, 3, 3) weight repacked to channels-last
    # -- confirm at the call site.
    ynumel = 65536
    xnumel = 9
    # Flatten the (axis 1, axis 2) program ids into a single y-tile index.
    # The former `tl.program_id(1) * (tl.program_id(2) + 1)` is only right when
    # axis 2 has one program. With more, every program with program_id(1) == 0
    # lands on tile 0 and many tiles are never visited.
    y_tile = tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)
    yoffset = y_tile * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Guard both axes in case the launched y programs cover more than ynumel rows.
    # With a single axis-2 program this is identical to the old xmask-only guard.
    mask = xmask & ymask
    x2 = xindex
    y3 = yindex
    y0 = yindex % 256
    y1 = yindex // 256
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), mask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (256*x2) + (2304*y1)), tmp0, mask)
- # kernel path: /tmp/torchinductor_chilli/ti/ctizppkxqrjhhsm7luucj6whivyl2vv4fjjdpas7rporotca2pry.py
- # Source Nodes: [identity_1, out_32, out_33, out_34], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- # identity_1 => add_65, add_68, mul_84, mul_90, rsqrt_12, sub_12, var_mean_12
- # out_32 => add_60, add_63, mul_77, mul_83, rsqrt_11, sub_11, var_mean_11
- # out_33 => add_69
- # out_34 => relu_10
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_add_relu_21(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, xnumel, XBLOCK : tl.constexpr):
    # Pointwise epilogue for identity_1 / out_32 / out_33 / out_34. Two inputs
    # are each normalised with their own per-channel operands (c = i % 256):
    #   a = (in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 225 + 1e-5) * in_ptr3[c] + in_ptr4[c]
    #   b = (in_ptr5[i] - in_ptr6[c]) * rsqrt(in_ptr7[c] / 225 + 1e-5) * in_ptr8[c] + in_ptr9[c]
    # The kernel then writes relu(a + b) to in_out_ptr0[i].
    xnumel = 57600
    flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = flat < xnumel
    chan = flat % 256
    a_src = tl.load(in_ptr0 + (flat), live)
    a_mu = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    a_m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    a_scale = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    a_shift = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    b_src = tl.load(in_ptr5 + (flat), live)
    b_mu = tl.load(in_ptr6 + (chan), live, eviction_policy='evict_last')
    b_m2 = tl.load(in_ptr7 + (chan), live, eviction_policy='evict_last')
    b_scale = tl.load(in_ptr8 + (chan), live, eviction_policy='evict_last')
    b_shift = tl.load(in_ptr9 + (chan), live, eviction_policy='evict_last')
    n_elems = 225.0
    eps = 1e-05
    a_norm = ((a_src - a_mu) * libdevice.rsqrt((a_m2 / n_elems) + eps)) * a_scale + a_shift
    b_norm = ((b_src - b_mu) * libdevice.rsqrt((b_m2 / n_elems) + eps)) * b_scale + b_shift
    result = triton_helpers.maximum(0, a_norm + b_norm)
    tl.store(in_out_ptr0 + (flat), result, live)
- # kernel path: /tmp/torchinductor_chilli/h4/ch4hlqd3tqqqn4auweqr3lccybiiv5c6i2hmtylkhw3i3ac6nvem.py
- # Source Nodes: [out_39, out_40, out_41], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- # out_39 => add_76, add_79, mul_104, mul_98, rsqrt_14, sub_14, var_mean_14
- # out_40 => add_80
- # out_41 => relu_12
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_add_relu_22(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, XBLOCK : tl.constexpr):
    # Pointwise epilogue for out_39 / out_40 / out_41:
    #   in_out_ptr0[i] = relu(((in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 225 + 1e-5))
    #                         * in_ptr3[c] + in_ptr4[c] + in_out_ptr0[i])
    # where c = i % 256 selects the per-channel operand.
    xnumel = 57600
    flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    live = flat < xnumel
    chan = flat % 256
    src = tl.load(in_ptr0 + (flat), live)
    mu = tl.load(in_ptr1 + (chan), live, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), live, eviction_policy='evict_last')
    scale = tl.load(in_ptr3 + (chan), live, eviction_policy='evict_last')
    shift = tl.load(in_ptr4 + (chan), live, eviction_policy='evict_last')
    # The addend lives in the output buffer; read it before overwriting it below.
    addend = tl.load(in_out_ptr0 + (flat), live)
    n_elems = 225.0
    eps = 1e-05
    inv_std = libdevice.rsqrt((m2 / n_elems) + eps)
    affine = ((src - mu) * inv_std) * scale + shift
    result = triton_helpers.maximum(0, affine + addend)
    tl.store(in_out_ptr0 + (flat), result, live)
- # kernel path: /tmp/torchinductor_chilli/aw/cawomzmzglbre7ilhcarrqi7hjfndlnawqmuj6pdrjwpbvtwu67i.py
- # Source Nodes: [out_42], Original ATen: [aten.convolution]
- # out_42 => convolution_15
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_convolution_23(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Layout-conversion copy for the out_42 convolution operand.
    # Reads in_ptr0 as a contiguous (ynumel, 9) array (element (y, x) at 9*y + x).
    # Writes it to out_ptr0 at y0 + 256*x + 2304*y1, with y0 = y % 256 and
    # y1 = y // 256, so the 9-element axis moves outside the 256-element axis.
    # NOTE(review): presumably a (512, 256, 3, 3) weight repacked to channels-last
    # -- confirm at the call site.
    ynumel = 131072
    xnumel = 9
    # Flatten the (axis 1, axis 2) program ids into a single y-tile index.
    # The former `tl.program_id(1) * (tl.program_id(2) + 1)` is only right when
    # axis 2 has one program. With more, every program with program_id(1) == 0
    # lands on tile 0 and many tiles are never visited.
    y_tile = tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)
    yoffset = y_tile * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Guard both axes in case the launched y programs cover more than ynumel rows.
    # With a single axis-2 program this is identical to the old xmask-only guard.
    mask = xmask & ymask
    x2 = xindex
    y3 = yindex
    y0 = yindex % 256
    y1 = yindex // 256
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), mask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (256*x2) + (2304*y1)), tmp0, mask)
- # kernel path: /tmp/torchinductor_chilli/fl/cflwvwmk3cuzfonagfvse3fly3w2z6wozgcbuc5yjwib3an7gogg.py
- # Source Nodes: [out_43], Original ATen: [aten._native_batch_norm_legit_functional]
- # out_43 => add_83, add_84, mul_106, mul_107, mul_108, mul_109, mul_110, var_mean_15
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_24(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Per-row statistics for out_43, computed over 64 elements per row.
    # Row x's elements are read at in_ptr0[x + 512*r]. The kernel stores:
    #   out_ptr0[x] = mean
    #   out_ptr1[x] = sum of squared deviations from the mean
    #   out_ptr3[x] = 0.1 * mean + 0.9 * in_ptr1[x]
    #   out_ptr5[x] = 0.1 * (M2 / 64) * (64 / 63) + 0.9 * in_ptr2[x]
    # NOTE(review): out_ptr3 / out_ptr5 look like updated running mean / var --
    # confirm at the call site.
    xnumel = 512
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = rows < xnumel
    cols = tl.arange(0, RBLOCK)[None, :]
    col_ok = cols < rnumel
    both_ok = col_ok & row_ok
    vals = tl.load(in_ptr0 + (rows + (512*cols)), both_ok, other=0.0)
    prev_mean = tl.load(in_ptr1 + (rows), row_ok, eviction_policy='evict_last')
    prev_var = tl.load(in_ptr2 + (rows), row_ok, eviction_policy='evict_last')
    tile = tl.broadcast_to(vals, [XBLOCK, RBLOCK])
    # Pass 1: mean over the valid columns.
    total = tl.sum(tl.where(both_ok, tile, 0), 1)[:, None]
    count = tl.full([XBLOCK, 1], 64, tl.int32).to(tl.float32)
    mean = total / count
    # Pass 2: sum of squared deviations. Masked lanes are zeroed by the where().
    dev = tile - mean
    sq = tl.broadcast_to(dev * dev, [XBLOCK, RBLOCK])
    m2 = tl.sum(tl.where(both_ok, sq, 0), 1)[:, None]
    momentum = 0.1
    keep = 0.9
    new_mean = mean * momentum + prev_mean * keep
    n_elems = 64.0
    bessel = 1.0158730158730158  # 64 / 63
    new_var = ((m2 / n_elems) * bessel) * momentum + prev_var * keep
    tl.store(out_ptr3 + (rows), new_mean, row_ok)
    tl.store(out_ptr5 + (rows), new_var, row_ok)
    tl.store(out_ptr0 + (rows), mean, row_ok)
    tl.store(out_ptr1 + (rows), m2, row_ok)
- # kernel path: /tmp/torchinductor_chilli/sy/csyscyubm24wopyxzglh7ys4ywxr3iuludm3tpzws4m5xt4glbd2.py
- # Source Nodes: [out_43, out_44], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- # out_43 => add_82, add_85, mul_105, mul_111, rsqrt_15, sub_15, var_mean_15
- # out_44 => relu_13
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_relu_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Pointwise epilogue for out_43 / out_44:
    #   out_ptr0[i] = relu(((in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 64 + 1e-5))
    #                      * in_ptr3[c] + in_ptr4[c])
    # where c = i % 512 selects the per-channel operand.
    # Loads and stores are unmasked (mask=None), exactly as generated.
    # NOTE(review): presumably 32768 is a multiple of every XBLOCK this kernel is
    # launched with -- confirm before reusing with other sizes.
    xnumel = 32768
    flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    chan = flat % 512
    src = tl.load(in_ptr0 + (flat), None)
    mu = tl.load(in_ptr1 + (chan), None, eviction_policy='evict_last')
    m2 = tl.load(in_ptr2 + (chan), None, eviction_policy='evict_last')
    scale = tl.load(in_ptr3 + (chan), None, eviction_policy='evict_last')
    shift = tl.load(in_ptr4 + (chan), None, eviction_policy='evict_last')
    n_elems = 64.0
    eps = 1e-05
    inv_std = libdevice.rsqrt((m2 / n_elems) + eps)
    affine = ((src - mu) * inv_std) * scale + shift
    tl.store(out_ptr0 + (flat), triton_helpers.maximum(0, affine), None)
- # kernel path: /tmp/torchinductor_chilli/zo/czoc4yry3r2czazn4sagx7wsfwtp3hosm7nwr7fekbym23kvtnfu.py
- # Source Nodes: [out_43, out_44, out_45], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- # out_43 => add_82, add_85, mul_105, mul_111, rsqrt_15, sub_15, var_mean_15
- # out_44 => relu_13
- # out_45 => convolution_16
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    # Layout-conversion copy for the out_45 convolution operand.
    # Reads in_ptr0 as a contiguous (ynumel, 9) array (element (y, x) at 9*y + x).
    # Writes it to out_ptr0 at y0 + 512*x + 4608*y1, with y0 = y % 512 and
    # y1 = y // 512, so the 9-element axis moves outside the 512-element axis.
    # NOTE(review): presumably a (512, 512, 3, 3) weight repacked to channels-last
    # -- confirm at the call site.
    ynumel = 262144
    xnumel = 9
    # Flatten the (axis 1, axis 2) program ids into a single y-tile index.
    # The former `tl.program_id(1) * (tl.program_id(2) + 1)` is only right when
    # axis 2 has one program. With more, every program with program_id(1) == 0
    # lands on tile 0 and many tiles are never visited.
    y_tile = tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)
    yoffset = y_tile * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Guard both axes in case the launched y programs cover more than ynumel rows.
    # With a single axis-2 program this is identical to the old xmask-only guard.
    mask = xmask & ymask
    x2 = xindex
    y3 = yindex
    y0 = yindex % 512
    y1 = yindex // 512
    tmp0 = tl.load(in_ptr0 + (x2 + (9*y3)), mask, eviction_policy='evict_last')
    tl.store(out_ptr0 + (y0 + (512*x2) + (4608*y1)), tmp0, mask)
- # kernel path: /tmp/torchinductor_chilli/x3/cx3julbo4ui5eikjj253cy76wq3pd3umd6aj6cjvz5lg7zerlerv.py
- # Source Nodes: [identity_2, out_46, out_47, out_48], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- # identity_2 => add_92, add_95, mul_119, mul_125, rsqrt_17, sub_17, var_mean_17
- # out_46 => add_87, add_90, mul_112, mul_118, rsqrt_16, sub_16, var_mean_16
- # out_47 => add_96
- # out_48 => relu_14
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused__native_batch_norm_legit_functional_add_relu_27(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, xnumel, XBLOCK : tl.constexpr):
    # Pointwise epilogue for identity_2 / out_46 / out_47 / out_48. Two inputs
    # are each normalised with their own per-channel operands (c = i % 512):
    #   a = (in_ptr0[i] - in_ptr1[c]) * rsqrt(in_ptr2[c] / 64 + 1e-5) * in_ptr3[c] + in_ptr4[c]
    #   b = (in_ptr5[i] - in_ptr6[c]) * rsqrt(in_ptr7[c] / 64 + 1e-5) * in_ptr8[c] + in_ptr9[c]
    # The kernel then writes relu(a + b) to in_out_ptr0[i].
    # Loads and stores are unmasked (mask=None), exactly as generated.
    # NOTE(review): presumably 32768 is a multiple of every XBLOCK used -- confirm.
    xnumel = 32768
    flat = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:]
    chan = flat % 512
    a_src = tl.load(in_ptr0 + (flat), None)
    a_mu = tl.load(in_ptr1 + (chan), None, eviction_policy='evict_last')
    a_m2 = tl.load(in_ptr2 + (chan), None, eviction_policy='evict_last')
    a_scale = tl.load(in_ptr3 + (chan), None, eviction_policy='evict_last')
    a_shift = tl.load(in_ptr4 + (chan), None, eviction_policy='evict_last')
    b_src = tl.load(in_ptr5 + (flat), None)
    b_mu = tl.load(in_ptr6 + (chan), None, eviction_policy='evict_last')
    b_m2 = tl.load(in_ptr7 + (chan), None, eviction_policy='evict_last')
    b_scale = tl.load(in_ptr8 + (chan), None, eviction_policy='evict_last')
    b_shift = tl.load(in_ptr9 + (chan), None, eviction_policy='evict_last')
    n_elems = 64.0
    eps = 1e-05
    a_norm = ((a_src - a_mu) * libdevice.rsqrt((a_m2 / n_elems) + eps)) * a_scale + a_shift
    b_norm = ((b_src - b_mu) * libdevice.rsqrt((b_m2 / n_elems) + eps)) * b_scale + b_shift
    result = triton_helpers.maximum(0, a_norm + b_norm)
    tl.store(in_out_ptr0 + (flat), result, None)
- # kernel path: /tmp/torchinductor_chilli/mr/cmr7ut36zrwmn4s73ifiyzecupyoozfkyvscgohgilbdwcv6vdgv.py
- # Source Nodes: [out_53, out_54, out_55, x_4], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.mean, aten.relu]
- # out_53 => add_103, add_104, add_105, add_106, mul_133, mul_134, mul_135, mul_136, mul_137, mul_138, mul_139, rsqrt_19, sub_19, var_mean_19
- # out_54 => add_107
- # out_55 => relu_16
- # x_4 => mean
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_per_fused__native_batch_norm_legit_functional_add_mean_relu_28(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr3, out_ptr5, xnumel, rnumel, XBLOCK : tl.constexpr):
    # Fused out_53 / out_54 / out_55 / x_4 for each of 512 rows x, each with 64
    # elements read at [x + 512*r]:
    #   1. mean and sum of squared deviations (M2) of in_ptr0 over the 64 elements;
    #   2. y = relu((in_ptr0 - mean) * rsqrt(M2 / 64 + 1e-5) * in_ptr1[x] + in_ptr2[x] + in_ptr3);
    #   3. in_out_ptr0[x] = mean of y over the 64 elements;
    #   4. out_ptr3[x] = 0.1 * mean + 0.9 * in_ptr4[x];
    #   5. out_ptr5[x] = 0.1 * (M2 / 64) * (64 / 63) + 0.9 * in_ptr5[x].
    xnumel = 512
    rnumel = 64
    RBLOCK: tl.constexpr = 64
    rows = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    row_ok = rows < xnumel
    cols = tl.arange(0, RBLOCK)[None, :]
    col_ok = cols < rnumel
    both_ok = col_ok & row_ok
    vals = tl.load(in_ptr0 + (rows + (512*cols)), both_ok, other=0.0)
    scale = tl.load(in_ptr1 + (rows), row_ok, eviction_policy='evict_last')
    shift = tl.load(in_ptr2 + (rows), row_ok, eviction_policy='evict_last')
    addend = tl.load(in_ptr3 + (rows + (512*cols)), both_ok, other=0.0)
    prev_mean = tl.load(in_ptr4 + (rows), row_ok, eviction_policy='evict_last')
    prev_var = tl.load(in_ptr5 + (rows), row_ok, eviction_policy='evict_last')
    tile = tl.broadcast_to(vals, [XBLOCK, RBLOCK])
    # Statistics: mean, then sum of squared deviations over valid columns only.
    total = tl.sum(tl.where(both_ok, tile, 0), 1)[:, None]
    count = tl.full([XBLOCK, 1], 64, tl.int32).to(tl.float32)
    mean = total / count
    dev = tile - mean
    sq = tl.broadcast_to(dev * dev, [XBLOCK, RBLOCK])
    m2 = tl.sum(tl.where(both_ok, sq, 0), 1)[:, None]
    n_elems = 64.0
    var = m2 / n_elems
    eps = 1e-05
    inv_std = libdevice.rsqrt(var + eps)
    # Normalise, apply scale/shift, add the second operand, then ReLU.
    affine = ((vals - mean) * inv_std) * scale + shift
    act = triton_helpers.maximum(0, affine + addend)
    # Average the activated values over the valid columns.
    act_tile = tl.broadcast_to(act, [XBLOCK, RBLOCK])
    pooled = tl.sum(tl.where(both_ok, act_tile, 0), 1)[:, None] / n_elems
    momentum = 0.1
    keep = 0.9
    new_mean = mean * momentum + prev_mean * keep
    bessel = 1.0158730158730158  # 64 / 63
    new_var = (var * bessel) * momentum + prev_var * keep
    # Barrier kept in the same position as the generated code, ahead of the
    # stores (one of which targets an in/out buffer).
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (rows), pooled, row_ok)
    tl.store(out_ptr3 + (rows), new_mean, row_ok)
    tl.store(out_ptr5 + (rows), new_var, row_ok)
- # kernel path: /tmp/torchinductor_chilli/nj/cnj4bf4ulsnoqw46btyaeowtw6gfffjthwhbb7vulfzsoz43jkdn.py
- # Source Nodes: [x_1], Original ATen: [aten.add]
- # x_1 => add
- import triton
- import triton.language as tl
- from triton.compiler.compiler import AttrsDescriptor
- from torch._inductor.runtime import triton_helpers, triton_heuristics
- from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
- from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton.jit
def triton_poi_fused_add_29(in_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    # x_1 => add: read the scalar at in_ptr0[0], add an int64 one and write the
    # sum to out_ptr1[0].
    # Every lane writes the same value to the same address (mask=None), as in
    # the generated code.
    # NOTE(review): in_ptr0 presumably holds an integer counter -- confirm its
    # dtype at the call site.
    xnumel = 1
    scalar = tl.load(in_ptr0 + (0))
    lanes = tl.broadcast_to(scalar, [XBLOCK])
    one = tl.full([1], 1, tl.int64)
    bumped = lanes + one
    dst = out_ptr1 + (tl.full([XBLOCK], 0, tl.int32))
    tl.store(dst, bumped, None)
- def call(args):
- arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1 = args
- args.clear()
- assert_size_stride(arg0_1, (64, 3, 7, 7), (147, 49, 7, 1))
- assert_size_stride(arg1_1, (64, ), (1, ))
- assert_size_stride(arg2_1, (64, ), (1, ))
- assert_size_stride(arg3_1, (64, 64, 3, 3), (576, 9, 3, 1))
- assert_size_stride(arg4_1, (64, ), (1, ))
- assert_size_stride(arg5_1, (64, ), (1, ))
- assert_size_stride(arg6_1, (64, 64, 3, 3), (576, 9, 3, 1))
- assert_size_stride(arg7_1, (64, ), (1, ))
- assert_size_stride(arg8_1, (64, ), (1, ))
- assert_size_stride(arg9_1, (64, 64, 3, 3), (576, 9, 3, 1))
- assert_size_stride(arg10_1, (64, ), (1, ))
- assert_size_stride(arg11_1, (64, ), (1, ))
- assert_size_stride(arg12_1, (64, 64, 3, 3), (576, 9, 3, 1))
- assert_size_stride(arg13_1, (64, ), (1, ))
- assert_size_stride(arg14_1, (64, ), (1, ))
- assert_size_stride(arg15_1, (128, 64, 3, 3), (576, 9, 3, 1))
- assert_size_stride(arg16_1, (128, ), (1, ))
- assert_size_stride(arg17_1, (128, ), (1, ))
- assert_size_stride(arg18_1, (128, 128, 3, 3), (1152, 9, 3, 1))
- assert_size_stride(arg19_1, (128, ), (1, ))
- assert_size_stride(arg20_1, (128, ), (1, ))
- assert_size_stride(arg21_1, (128, 64, 1, 1), (64, 1, 1, 1))
- assert_size_stride(arg22_1, (128, ), (1, ))
- assert_size_stride(arg23_1, (128, ), (1, ))
- assert_size_stride(arg24_1, (128, 128, 3, 3), (1152, 9, 3, 1))
- assert_size_stride(arg25_1, (128, ), (1, ))
- assert_size_stride(arg26_1, (128, ), (1, ))
- assert_size_stride(arg27_1, (128, 128, 3, 3), (1152, 9, 3, 1))
- assert_size_stride(arg28_1, (128, ), (1, ))
- assert_size_stride(arg29_1, (128, ), (1, ))
- assert_size_stride(arg30_1, (256, 128, 3, 3), (1152, 9, 3, 1))
- assert_size_stride(arg31_1, (256, ), (1, ))
- assert_size_stride(arg32_1, (256, ), (1, ))
- assert_size_stride(arg33_1, (256, 256, 3, 3), (2304, 9, 3, 1))
- assert_size_stride(arg34_1, (256, ), (1, ))
- assert_size_stride(arg35_1, (256, ), (1, ))
- assert_size_stride(arg36_1, (256, 128, 1, 1), (128, 1, 1, 1))
- assert_size_stride(arg37_1, (256, ), (1, ))
- assert_size_stride(arg38_1, (256, ), (1, ))
- assert_size_stride(arg39_1, (256, 256, 3, 3), (2304, 9, 3, 1))
- assert_size_stride(arg40_1, (256, ), (1, ))
- assert_size_stride(arg41_1, (256, ), (1, ))
- assert_size_stride(arg42_1, (256, 256, 3, 3), (2304, 9, 3, 1))
- assert_size_stride(arg43_1, (256, ), (1, ))
- assert_size_stride(arg44_1, (256, ), (1, ))
- assert_size_stride(arg45_1, (512, 256, 3, 3), (2304, 9, 3, 1))
- assert_size_stride(arg46_1, (512, ), (1, ))
- assert_size_stride(arg47_1, (512, ), (1, ))
- assert_size_stride(arg48_1, (512, 512, 3, 3), (4608, 9, 3, 1))
- assert_size_stride(arg49_1, (512, ), (1, ))
- assert_size_stride(arg50_1, (512, ), (1, ))
- assert_size_stride(arg51_1, (512, 256, 1, 1), (256, 1, 1, 1))
- assert_size_stride(arg52_1, (512, ), (1, ))
- assert_size_stride(arg53_1, (512, ), (1, ))
- assert_size_stride(arg54_1, (512, 512, 3, 3), (4608, 9, 3, 1))
- assert_size_stride(arg55_1, (512, ), (1, ))
- assert_size_stride(arg56_1, (512, ), (1, ))
- assert_size_stride(arg57_1, (512, 512, 3, 3), (4608, 9, 3, 1))
- assert_size_stride(arg58_1, (512, ), (1, ))
- assert_size_stride(arg59_1, (512, ), (1, ))
- assert_size_stride(arg60_1, (1000, 512), (512, 1))
- assert_size_stride(arg61_1, (1000, ), (1, ))
- assert_size_stride(arg62_1, (64, ), (1, ))
- assert_size_stride(arg63_1, (64, ), (1, ))
- assert_size_stride(arg64_1, (), ())
- assert_size_stride(arg65_1, (64, ), (1, ))
- assert_size_stride(arg66_1, (64, ), (1, ))
- assert_size_stride(arg67_1, (), ())
- assert_size_stride(arg68_1, (64, ), (1, ))
- assert_size_stride(arg69_1, (64, ), (1, ))
- assert_size_stride(arg70_1, (), ())
- assert_size_stride(arg71_1, (64, ), (1, ))
- assert_size_stride(arg72_1, (64, ), (1, ))
- assert_size_stride(arg73_1, (), ())
- assert_size_stride(arg74_1, (64, ), (1, ))
- assert_size_stride(arg75_1, (64, ), (1, ))
- assert_size_stride(arg76_1, (), ())
- assert_size_stride(arg77_1, (128, ), (1, ))
- assert_size_stride(arg78_1, (128, ), (1, ))
- assert_size_stride(arg79_1, (), ())
- assert_size_stride(arg80_1, (128, ), (1, ))
- assert_size_stride(arg81_1, (128, ), (1, ))
- assert_size_stride(arg82_1, (), ())
- assert_size_stride(arg83_1, (128, ), (1, ))
- assert_size_stride(arg84_1, (128, ), (1, ))
- assert_size_stride(arg85_1, (), ())
- assert_size_stride(arg86_1, (128, ), (1, ))
- assert_size_stride(arg87_1, (128, ), (1, ))
- assert_size_stride(arg88_1, (), ())
- assert_size_stride(arg89_1, (128, ), (1, ))
- assert_size_stride(arg90_1, (128, ), (1, ))
- assert_size_stride(arg91_1, (), ())
- assert_size_stride(arg92_1, (256, ), (1, ))
- assert_size_stride(arg93_1, (256, ), (1, ))
- assert_size_stride(arg94_1, (), ())
- assert_size_stride(arg95_1, (256, ), (1, ))
- assert_size_stride(arg96_1, (256, ), (1, ))
- assert_size_stride(arg97_1, (), ())
- assert_size_stride(arg98_1, (256, ), (1, ))
- assert_size_stride(arg99_1, (256, ), (1, ))
- assert_size_stride(arg100_1, (), ())
- assert_size_stride(arg101_1, (256, ), (1, ))
- assert_size_stride(arg102_1, (256, ), (1, ))
- assert_size_stride(arg103_1, (), ())
- assert_size_stride(arg104_1, (256, ), (1, ))
- assert_size_stride(arg105_1, (256, ), (1, ))
- assert_size_stride(arg106_1, (), ())
- assert_size_stride(arg107_1, (512, ), (1, ))
- assert_size_stride(arg108_1, (512, ), (1, ))
- assert_size_stride(arg109_1, (), ())
- assert_size_stride(arg110_1, (512, ), (1, ))
- assert_size_stride(arg111_1, (512, ), (1, ))
- assert_size_stride(arg112_1, (), ())
- assert_size_stride(arg113_1, (512, ), (1, ))
- assert_size_stride(arg114_1, (512, ), (1, ))
- assert_size_stride(arg115_1, (), ())
- assert_size_stride(arg116_1, (512, ), (1, ))
- assert_size_stride(arg117_1, (512, ), (1, ))
- assert_size_stride(arg118_1, (), ())
- assert_size_stride(arg119_1, (512, ), (1, ))
- assert_size_stride(arg120_1, (512, ), (1, ))
- assert_size_stride(arg121_1, (), ())
- assert_size_stride(arg122_1, (1, 3, 228, 228), (155952, 51984, 228, 1))
- with torch.cuda._DeviceGuard(0):
- torch.cuda.set_device(0)
- buf0 = empty_strided_cuda((1, 3, 228, 228), (155952, 1, 684, 3), torch.float32)
- # Source Nodes: [x], Original ATen: [aten.convolution]
- stream0 = get_raw_stream(0)
- triton_poi_fused_convolution_0[grid(3, 51984)](arg122_1, buf0, 3, 51984, XBLOCK=256, YBLOCK=4, num_warps=4, num_stages=1)
- del arg122_1
- buf1 = empty_strided_cuda((64, 3, 7, 7), (147, 1, 21, 3), torch.float32)
- # Source Nodes: [x], Original ATen: [aten.convolution]
- triton_poi_fused_convolution_1[grid(192, 49)](arg0_1, buf1, 192, 49, XBLOCK=32, YBLOCK=32, num_warps=4, num_stages=1)
- del arg0_1
- # Source Nodes: [x], Original ATen: [aten.convolution]
- buf2 = extern_kernels.convolution(buf0, buf1, stride=(2, 2), padding=(3, 3), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf2, (1, 64, 114, 114), (831744, 1, 7296, 64))
- del buf0
- del buf1
- buf3 = empty_strided_cuda((1, 64, 1, 1, 2, 51), (6528, 1, 6528, 6528, 3264, 64), torch.float32)
- buf4 = empty_strided_cuda((1, 64, 1, 1, 2, 51), (6528, 1, 6528, 6528, 3264, 64), torch.float32)
- buf5 = empty_strided_cuda((1, 64, 1, 1, 2, 51), (6528, 1, 6528, 6528, 3264, 64), torch.float32)
- # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_2[grid(6528)](buf2, buf3, buf4, buf5, 6528, 128, XBLOCK=32, num_warps=8, num_stages=1)
- buf6 = empty_strided_cuda((1, 64, 1, 1, 2), (128, 1, 128, 128, 64), torch.float32)
- buf7 = empty_strided_cuda((1, 64, 1, 1, 2), (128, 1, 128, 128, 64), torch.float32)
- buf8 = empty_strided_cuda((1, 64, 1, 1, 2), (128, 1, 128, 128, 64), torch.float32)
- # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_3[grid(128)](buf3, buf4, buf5, buf6, buf7, buf8, 128, 51, XBLOCK=1, num_warps=2, num_stages=1)
- del buf3
- del buf4
- del buf5
- buf9 = empty_strided_cuda((1, 64, 1, 1), (64, 1, 1, 1), torch.float32)
- buf10 = empty_strided_cuda((1, 64, 1, 1), (64, 1, 64, 64), torch.float32)
- # Source Nodes: [x_1], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_4[grid(64)](buf6, buf7, buf8, arg62_1, arg63_1, buf9, buf10, arg62_1, arg63_1, 64, 2, XBLOCK=8, num_warps=2, num_stages=1)
- del arg62_1
- del arg63_1
- buf12 = buf2; del buf2 # reuse
- # Source Nodes: [x_1, x_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_5[grid(831744)](buf12, buf9, buf10, arg1_1, arg2_1, 831744, XBLOCK=1024, num_warps=4, num_stages=1)
- del arg1_1
- del arg2_1
- buf13 = empty_strided_cuda((1, 64, 57, 57), (207936, 1, 3648, 64), torch.float32)
- # Source Nodes: [x_1, x_2, x_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.max_pool2d_with_indices, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_max_pool2d_with_indices_relu_6[grid(207936)](buf12, buf13, 207936, XBLOCK=512, num_warps=8, num_stages=1)
- del buf12
- buf14 = empty_strided_cuda((64, 64, 3, 3), (576, 1, 192, 64), torch.float32)
- # Source Nodes: [out], Original ATen: [aten.convolution]
- triton_poi_fused_convolution_7[grid(4096, 9)](arg3_1, buf14, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg3_1
- # Source Nodes: [out], Original ATen: [aten.convolution]
- buf15 = extern_kernels.convolution(buf13, buf14, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf15, (1, 64, 57, 57), (207936, 1, 3648, 64))
- buf16 = buf9; del buf9 # reuse
- buf17 = buf10; del buf10 # reuse
- # Source Nodes: [out_1], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf15, arg65_1, arg66_1, buf16, buf17, arg65_1, arg66_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
- del arg65_1
- del arg66_1
- buf19 = empty_strided_cuda((1, 64, 57, 57), (207936, 1, 3648, 64), torch.float32)
- # Source Nodes: [out_1, out_2], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_9[grid(207936)](buf15, buf16, buf17, arg4_1, arg5_1, buf19, 207936, XBLOCK=1024, num_warps=4, num_stages=1)
- del arg4_1
- del arg5_1
- del buf15
- buf20 = buf14; del buf14 # reuse
- # Source Nodes: [out_1, out_2, out_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused_convolution_7[grid(4096, 9)](arg6_1, buf20, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg6_1
- # Source Nodes: [out_1, out_2, out_3], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf21 = extern_kernels.convolution(buf19, buf20, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf21, (1, 64, 57, 57), (207936, 1, 3648, 64))
- del buf19
- buf22 = reinterpret_tensor(buf17, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf17 # reuse
- buf23 = reinterpret_tensor(buf16, (1, 64, 1, 1), (64, 1, 64, 64), 0); del buf16 # reuse
- # Source Nodes: [out_4], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf21, arg68_1, arg69_1, buf22, buf23, arg68_1, arg69_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
- del arg68_1
- del arg69_1
- buf25 = buf13; del buf13 # reuse
- # Source Nodes: [out_4, out_5, out_6], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_10[grid(207936)](buf25, buf21, buf22, buf23, arg7_1, arg8_1, 207936, XBLOCK=512, num_warps=8, num_stages=1)
- del arg7_1
- del arg8_1
- buf26 = buf20; del buf20 # reuse
- # Source Nodes: [out_7], Original ATen: [aten.convolution]
- triton_poi_fused_convolution_7[grid(4096, 9)](arg9_1, buf26, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg9_1
- # Source Nodes: [out_7], Original ATen: [aten.convolution]
- buf27 = extern_kernels.convolution(buf25, buf26, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf27, (1, 64, 57, 57), (207936, 1, 3648, 64))
- buf28 = reinterpret_tensor(buf23, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf23 # reuse
- buf29 = reinterpret_tensor(buf22, (1, 64, 1, 1), (64, 1, 64, 64), 0); del buf22 # reuse
- # Source Nodes: [out_8], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf27, arg71_1, arg72_1, buf28, buf29, arg71_1, arg72_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
- del arg71_1
- del arg72_1
- buf31 = buf21; del buf21 # reuse
- # Source Nodes: [out_8, out_9], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_9[grid(207936)](buf27, buf28, buf29, arg10_1, arg11_1, buf31, 207936, XBLOCK=1024, num_warps=4, num_stages=1)
- del arg10_1
- del arg11_1
- del buf27
- buf32 = buf26; del buf26 # reuse
- # Source Nodes: [out_10, out_8, out_9], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused_convolution_7[grid(4096, 9)](arg12_1, buf32, 4096, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg12_1
- # Source Nodes: [out_10, out_8, out_9], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf33 = extern_kernels.convolution(buf31, buf32, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf33, (1, 64, 57, 57), (207936, 1, 3648, 64))
- del buf31
- del buf32
- buf34 = reinterpret_tensor(buf29, (1, 64, 1, 1), (64, 1, 1, 1), 0); del buf29 # reuse
- buf35 = reinterpret_tensor(buf28, (1, 64, 1, 1), (64, 1, 64, 64), 0); del buf28 # reuse
- # Source Nodes: [out_11], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_red_fused__native_batch_norm_legit_functional_8[grid(64)](buf33, arg74_1, arg75_1, buf34, buf35, arg74_1, arg75_1, 64, 3249, XBLOCK=1, RBLOCK=2048, num_warps=8, num_stages=1)
- del arg74_1
- del arg75_1
- buf37 = buf25; del buf25 # reuse
- # Source Nodes: [out_11, out_12, out_13], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_10[grid(207936)](buf37, buf33, buf34, buf35, arg13_1, arg14_1, 207936, XBLOCK=512, num_warps=8, num_stages=1)
- del arg13_1
- del arg14_1
- del buf33
- del buf34
- del buf35
- buf38 = empty_strided_cuda((128, 64, 3, 3), (576, 1, 192, 64), torch.float32)
- # Source Nodes: [out_14], Original ATen: [aten.convolution]
- triton_poi_fused_convolution_11[grid(8192, 9)](arg15_1, buf38, 8192, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg15_1
- # Source Nodes: [out_14], Original ATen: [aten.convolution]
- buf39 = extern_kernels.convolution(buf37, buf38, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf39, (1, 128, 29, 29), (107648, 1, 3712, 128))
- del buf38
- buf40 = reinterpret_tensor(buf8, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf8 # reuse
- buf41 = reinterpret_tensor(buf7, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf7 # reuse
- # Source Nodes: [out_15], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf39, arg77_1, arg78_1, buf40, buf41, arg77_1, arg78_1, 128, 841, num_warps=8, num_stages=1)
- del arg77_1
- del arg78_1
- buf43 = empty_strided_cuda((1, 128, 29, 29), (107648, 1, 3712, 128), torch.float32)
- # Source Nodes: [out_15, out_16], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_13[grid(107648)](buf39, buf40, buf41, arg16_1, arg17_1, buf43, 107648, XBLOCK=512, num_warps=8, num_stages=1)
- del arg16_1
- del arg17_1
- del buf39
- buf44 = empty_strided_cuda((128, 128, 3, 3), (1152, 1, 384, 128), torch.float32)
- # Source Nodes: [out_15, out_16, out_17], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14[grid(16384, 9)](arg18_1, buf44, 16384, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg18_1
- # Source Nodes: [out_15, out_16, out_17], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf45 = extern_kernels.convolution(buf43, buf44, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf45, (1, 128, 29, 29), (107648, 1, 3712, 128))
- buf46 = reinterpret_tensor(buf41, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf41 # reuse
- buf47 = reinterpret_tensor(buf40, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf40 # reuse
- # Source Nodes: [out_18], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf45, arg80_1, arg81_1, buf46, buf47, arg80_1, arg81_1, 128, 841, num_warps=8, num_stages=1)
- del arg80_1
- del arg81_1
- # Source Nodes: [getattr_l__self___layer2___0___downsample_0], Original ATen: [aten.convolution]
- buf49 = extern_kernels.convolution(buf37, arg21_1, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf49, (1, 128, 29, 29), (107648, 1, 3712, 128))
- del arg21_1
- del buf37
- buf50 = reinterpret_tensor(buf6, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf6 # reuse
- buf51 = empty_strided_cuda((1, 128, 1, 1), (128, 1, 128, 128), torch.float32)
- # Source Nodes: [identity], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf49, arg83_1, arg84_1, buf50, buf51, arg83_1, arg84_1, 128, 841, num_warps=8, num_stages=1)
- del arg83_1
- del arg84_1
- buf53 = buf43; del buf43 # reuse
- buf54 = buf53; del buf53 # reuse
- # Source Nodes: [identity, out_18, out_19, out_20], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_15[grid(107648)](buf54, buf45, buf46, buf47, arg19_1, arg20_1, buf49, buf50, buf51, arg22_1, arg23_1, 107648, XBLOCK=512, num_warps=8, num_stages=1)
- del arg19_1
- del arg20_1
- del arg22_1
- del arg23_1
- del buf45
- del buf46
- del buf47
- buf55 = buf44; del buf44 # reuse
- # Source Nodes: [out_20, out_21], Original ATen: [aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14[grid(16384, 9)](arg24_1, buf55, 16384, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg24_1
- # Source Nodes: [out_20, out_21], Original ATen: [aten.convolution, aten.relu]
- buf56 = extern_kernels.convolution(buf54, buf55, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf56, (1, 128, 29, 29), (107648, 1, 3712, 128))
- buf57 = reinterpret_tensor(buf51, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf51 # reuse
- buf58 = reinterpret_tensor(buf50, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf50 # reuse
- # Source Nodes: [out_22], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf56, arg86_1, arg87_1, buf57, buf58, arg86_1, arg87_1, 128, 841, num_warps=8, num_stages=1)
- del arg86_1
- del arg87_1
- buf60 = buf49; del buf49 # reuse
- # Source Nodes: [out_22, out_23], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_13[grid(107648)](buf56, buf57, buf58, arg25_1, arg26_1, buf60, 107648, XBLOCK=512, num_warps=8, num_stages=1)
- del arg25_1
- del arg26_1
- del buf56
- buf61 = buf55; del buf55 # reuse
- # Source Nodes: [out_22, out_23, out_24], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_14[grid(16384, 9)](arg27_1, buf61, 16384, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg27_1
- # Source Nodes: [out_22, out_23, out_24], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf62 = extern_kernels.convolution(buf60, buf61, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf62, (1, 128, 29, 29), (107648, 1, 3712, 128))
- del buf60
- del buf61
- buf63 = reinterpret_tensor(buf58, (1, 128, 1, 1), (128, 1, 1, 1), 0); del buf58 # reuse
- buf64 = reinterpret_tensor(buf57, (1, 128, 1, 1), (128, 1, 128, 128), 0); del buf57 # reuse
- # Source Nodes: [out_25], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_12[grid(128)](buf62, arg89_1, arg90_1, buf63, buf64, arg89_1, arg90_1, 128, 841, num_warps=8, num_stages=1)
- del arg89_1
- del arg90_1
- buf66 = buf54; del buf54 # reuse
- # Source Nodes: [out_25, out_26, out_27], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_16[grid(107648)](buf66, buf62, buf63, buf64, arg28_1, arg29_1, 107648, XBLOCK=512, num_warps=8, num_stages=1)
- del arg28_1
- del arg29_1
- del buf62
- del buf63
- del buf64
- buf67 = empty_strided_cuda((256, 128, 3, 3), (1152, 1, 384, 128), torch.float32)
- # Source Nodes: [out_28], Original ATen: [aten.convolution]
- triton_poi_fused_convolution_17[grid(32768, 9)](arg30_1, buf67, 32768, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg30_1
- # Source Nodes: [out_28], Original ATen: [aten.convolution]
- buf68 = extern_kernels.convolution(buf66, buf67, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf68, (1, 256, 15, 15), (57600, 1, 3840, 256))
- del buf67
- buf69 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 1, 1), torch.float32)
- buf70 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 256, 256), torch.float32)
- # Source Nodes: [out_29], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf68, arg92_1, arg93_1, buf69, buf70, arg92_1, arg93_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
- del arg92_1
- del arg93_1
- buf72 = empty_strided_cuda((1, 256, 15, 15), (57600, 1, 3840, 256), torch.float32)
- # Source Nodes: [out_29, out_30], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_19[grid(57600)](buf68, buf69, buf70, arg31_1, arg32_1, buf72, 57600, XBLOCK=512, num_warps=4, num_stages=1)
- del arg31_1
- del arg32_1
- del buf68
- buf73 = empty_strided_cuda((256, 256, 3, 3), (2304, 1, 768, 256), torch.float32)
- # Source Nodes: [out_29, out_30, out_31], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20[grid(65536, 9)](arg33_1, buf73, 65536, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg33_1
- # Source Nodes: [out_29, out_30, out_31], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf74 = extern_kernels.convolution(buf72, buf73, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf74, (1, 256, 15, 15), (57600, 1, 3840, 256))
- buf75 = reinterpret_tensor(buf70, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf70 # reuse
- buf76 = reinterpret_tensor(buf69, (1, 256, 1, 1), (256, 1, 256, 256), 0); del buf69 # reuse
- # Source Nodes: [out_32], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf74, arg95_1, arg96_1, buf75, buf76, arg95_1, arg96_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
- del arg95_1
- del arg96_1
- # Source Nodes: [getattr_l__self___layer3___0___downsample_0], Original ATen: [aten.convolution]
- buf78 = extern_kernels.convolution(buf66, arg36_1, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf78, (1, 256, 15, 15), (57600, 1, 3840, 256))
- del arg36_1
- del buf66
- buf79 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 1, 1), torch.float32)
- buf80 = empty_strided_cuda((1, 256, 1, 1), (256, 1, 256, 256), torch.float32)
- # Source Nodes: [identity_1], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf78, arg98_1, arg99_1, buf79, buf80, arg98_1, arg99_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
- del arg98_1
- del arg99_1
- buf82 = buf72; del buf72 # reuse
- buf83 = buf82; del buf82 # reuse
- # Source Nodes: [identity_1, out_32, out_33, out_34], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_21[grid(57600)](buf83, buf74, buf75, buf76, arg34_1, arg35_1, buf78, buf79, buf80, arg37_1, arg38_1, 57600, XBLOCK=256, num_warps=4, num_stages=1)
- del arg34_1
- del arg35_1
- del arg37_1
- del arg38_1
- del buf74
- del buf75
- del buf76
- buf84 = buf73; del buf73 # reuse
- # Source Nodes: [out_34, out_35], Original ATen: [aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20[grid(65536, 9)](arg39_1, buf84, 65536, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg39_1
- # Source Nodes: [out_34, out_35], Original ATen: [aten.convolution, aten.relu]
- buf85 = extern_kernels.convolution(buf83, buf84, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf85, (1, 256, 15, 15), (57600, 1, 3840, 256))
- buf86 = reinterpret_tensor(buf80, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf80 # reuse
- buf87 = reinterpret_tensor(buf79, (1, 256, 1, 1), (256, 1, 256, 256), 0); del buf79 # reuse
- # Source Nodes: [out_36], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf85, arg101_1, arg102_1, buf86, buf87, arg101_1, arg102_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
- del arg101_1
- del arg102_1
- buf89 = buf78; del buf78 # reuse
- # Source Nodes: [out_36, out_37], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_19[grid(57600)](buf85, buf86, buf87, arg40_1, arg41_1, buf89, 57600, XBLOCK=512, num_warps=4, num_stages=1)
- del arg40_1
- del arg41_1
- del buf85
- buf90 = buf84; del buf84 # reuse
- # Source Nodes: [out_36, out_37, out_38], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_20[grid(65536, 9)](arg42_1, buf90, 65536, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg42_1
- # Source Nodes: [out_36, out_37, out_38], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf91 = extern_kernels.convolution(buf89, buf90, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf91, (1, 256, 15, 15), (57600, 1, 3840, 256))
- del buf89
- del buf90
- buf92 = reinterpret_tensor(buf87, (1, 256, 1, 1), (256, 1, 1, 1), 0); del buf87 # reuse
- buf93 = reinterpret_tensor(buf86, (1, 256, 1, 1), (256, 1, 256, 256), 0); del buf86 # reuse
- # Source Nodes: [out_39], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_18[grid(256)](buf91, arg104_1, arg105_1, buf92, buf93, arg104_1, arg105_1, 256, 225, XBLOCK=1, num_warps=2, num_stages=1)
- del arg104_1
- del arg105_1
- buf95 = buf83; del buf83 # reuse
- # Source Nodes: [out_39, out_40, out_41], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_22[grid(57600)](buf95, buf91, buf92, buf93, arg43_1, arg44_1, 57600, XBLOCK=256, num_warps=4, num_stages=1)
- del arg43_1
- del arg44_1
- del buf91
- del buf92
- del buf93
- buf96 = empty_strided_cuda((512, 256, 3, 3), (2304, 1, 768, 256), torch.float32)
- # Source Nodes: [out_42], Original ATen: [aten.convolution]
- triton_poi_fused_convolution_23[grid(131072, 9)](arg45_1, buf96, 131072, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg45_1
- # Source Nodes: [out_42], Original ATen: [aten.convolution]
- buf97 = extern_kernels.convolution(buf95, buf96, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf97, (1, 512, 8, 8), (32768, 1, 4096, 512))
- del buf96
- buf98 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 1, 1), torch.float32)
- buf99 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 512, 512), torch.float32)
- # Source Nodes: [out_43], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf97, arg107_1, arg108_1, buf98, buf99, arg107_1, arg108_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
- del arg107_1
- del arg108_1
- buf101 = empty_strided_cuda((1, 512, 8, 8), (32768, 1, 4096, 512), torch.float32)
- # Source Nodes: [out_43, out_44], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_25[grid(32768)](buf97, buf98, buf99, arg46_1, arg47_1, buf101, 32768, XBLOCK=128, num_warps=4, num_stages=1)
- del arg46_1
- del arg47_1
- del buf97
- buf102 = empty_strided_cuda((512, 512, 3, 3), (4608, 1, 1536, 512), torch.float32)
- # Source Nodes: [out_43, out_44, out_45], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26[grid(262144, 9)](arg48_1, buf102, 262144, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg48_1
- # Source Nodes: [out_43, out_44, out_45], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf103 = extern_kernels.convolution(buf101, buf102, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf103, (1, 512, 8, 8), (32768, 1, 4096, 512))
- buf104 = reinterpret_tensor(buf99, (1, 512, 1, 1), (512, 1, 1, 1), 0); del buf99 # reuse
- buf105 = reinterpret_tensor(buf98, (1, 512, 1, 1), (512, 1, 512, 512), 0); del buf98 # reuse
- # Source Nodes: [out_46], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf103, arg110_1, arg111_1, buf104, buf105, arg110_1, arg111_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
- del arg110_1
- del arg111_1
- # Source Nodes: [getattr_l__self___layer4___0___downsample_0], Original ATen: [aten.convolution]
- buf107 = extern_kernels.convolution(buf95, arg51_1, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf107, (1, 512, 8, 8), (32768, 1, 4096, 512))
- del arg51_1
- del buf95
- buf108 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 1, 1), torch.float32)
- buf109 = empty_strided_cuda((1, 512, 1, 1), (512, 1, 512, 512), torch.float32)
- # Source Nodes: [identity_2], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf107, arg113_1, arg114_1, buf108, buf109, arg113_1, arg114_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
- del arg113_1
- del arg114_1
- buf111 = buf101; del buf101 # reuse
- buf112 = buf111; del buf111 # reuse
- # Source Nodes: [identity_2, out_46, out_47, out_48], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_add_relu_27[grid(32768)](buf112, buf103, buf104, buf105, arg49_1, arg50_1, buf107, buf108, buf109, arg52_1, arg53_1, 32768, XBLOCK=256, num_warps=4, num_stages=1)
- del arg49_1
- del arg50_1
- del arg52_1
- del arg53_1
- del buf103
- del buf104
- del buf105
- buf113 = buf102; del buf102 # reuse
- # Source Nodes: [out_48, out_49], Original ATen: [aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26[grid(262144, 9)](arg54_1, buf113, 262144, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg54_1
- # Source Nodes: [out_48, out_49], Original ATen: [aten.convolution, aten.relu]
- buf114 = extern_kernels.convolution(buf112, buf113, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf114, (1, 512, 8, 8), (32768, 1, 4096, 512))
- buf115 = reinterpret_tensor(buf109, (1, 512, 1, 1), (512, 1, 1, 1), 0); del buf109 # reuse
- buf116 = reinterpret_tensor(buf108, (1, 512, 1, 1), (512, 1, 512, 512), 0); del buf108 # reuse
- # Source Nodes: [out_50], Original ATen: [aten._native_batch_norm_legit_functional]
- triton_per_fused__native_batch_norm_legit_functional_24[grid(512)](buf114, arg116_1, arg117_1, buf115, buf116, arg116_1, arg117_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
- del arg116_1
- del arg117_1
- buf118 = buf107; del buf107 # reuse
- # Source Nodes: [out_50, out_51], Original ATen: [aten._native_batch_norm_legit_functional, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_relu_25[grid(32768)](buf114, buf115, buf116, arg55_1, arg56_1, buf118, 32768, XBLOCK=128, num_warps=4, num_stages=1)
- del arg55_1
- del arg56_1
- del buf114
- del buf115
- buf119 = buf113; del buf113 # reuse
- # Source Nodes: [out_50, out_51, out_52], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- triton_poi_fused__native_batch_norm_legit_functional_convolution_relu_26[grid(262144, 9)](arg57_1, buf119, 262144, 9, XBLOCK=16, YBLOCK=64, num_warps=4, num_stages=1)
- del arg57_1
- # Source Nodes: [out_50, out_51, out_52], Original ATen: [aten._native_batch_norm_legit_functional, aten.convolution, aten.relu]
- buf120 = extern_kernels.convolution(buf118, buf119, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
- assert_size_stride(buf120, (1, 512, 8, 8), (32768, 1, 4096, 512))
- del buf118
- del buf119
- buf124 = buf116; del buf116 # reuse
- buf125 = buf124; del buf124 # reuse
- # Source Nodes: [out_53, out_54, out_55, x_4], Original ATen: [aten._native_batch_norm_legit_functional, aten.add, aten.mean, aten.relu]
- triton_per_fused__native_batch_norm_legit_functional_add_mean_relu_28[grid(512)](buf125, buf120, arg58_1, arg59_1, buf112, arg119_1, arg120_1, arg119_1, arg120_1, 512, 64, XBLOCK=32, num_warps=8, num_stages=1)
- del arg119_1
- del arg120_1
- del arg58_1
- del arg59_1
- del buf112
- del buf120
- buf126 = empty_strided_cuda((1, 1000), (1000, 1), torch.float32)
- # Source Nodes: [x_6], Original ATen: [aten.addmm]
- extern_kernels.addmm(arg61_1, reinterpret_tensor(buf125, (1, 512), (0, 1), 0), reinterpret_tensor(arg60_1, (512, 1000), (1, 512), 0), alpha=1, beta=1, out=buf126)
- del arg60_1
- del arg61_1
- del buf125
- # Source Nodes: [x_1], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg64_1, arg64_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg64_1
- # Source Nodes: [out_1], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg67_1, arg67_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg67_1
- # Source Nodes: [out_4], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg70_1, arg70_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg70_1
- # Source Nodes: [out_8], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg73_1, arg73_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg73_1
- # Source Nodes: [out_11], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg76_1, arg76_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg76_1
- # Source Nodes: [out_15], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg79_1, arg79_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg79_1
- # Source Nodes: [out_18], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg82_1, arg82_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg82_1
- # Source Nodes: [identity], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg85_1, arg85_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg85_1
- # Source Nodes: [out_22], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg88_1, arg88_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg88_1
- # Source Nodes: [out_25], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg91_1, arg91_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg91_1
- # Source Nodes: [out_29], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg94_1, arg94_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg94_1
- # Source Nodes: [out_32], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg97_1, arg97_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg97_1
- # Source Nodes: [identity_1], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg100_1, arg100_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg100_1
- # Source Nodes: [out_36], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg103_1, arg103_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg103_1
- # Source Nodes: [out_39], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg106_1, arg106_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg106_1
- # Source Nodes: [out_43], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg109_1, arg109_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg109_1
- # Source Nodes: [out_46], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg112_1, arg112_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg112_1
- # Source Nodes: [identity_2], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg115_1, arg115_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg115_1
- # Source Nodes: [out_50], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg118_1, arg118_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg118_1
- # Source Nodes: [out_53], Original ATen: [aten.add]
- triton_poi_fused_add_29[grid(1)](arg121_1, arg121_1, 1, XBLOCK=1, num_warps=1, num_stages=1)
- del arg121_1
- return (buf126, )
def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on freshly generated random inputs and report the result.

    The inputs mirror, position by position, the 123 arguments ``call`` is
    invoked with:

    * positions 0-59: 20 convolution weights, each followed by two float32
      vectors of length ``weight.shape[0]``;
    * positions 60-61: a (1000, 512) float32 matrix and a (1000,) vector;
    * positions 62-121: for each of the same 20 convolutions, two float32
      vectors of length ``weight.shape[0]`` and one int64 scalar;
    * position 122: a (1, 3, 228, 228) float32 tensor.

    Every tensor is created on ``cuda:0`` with row-major (contiguous) strides.

    Args:
        times: forwarded unchanged to ``print_performance``.
        repeat: forwarded unchanged to ``print_performance``.

    Returns:
        Whatever ``print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    def _contiguous_strides(shape):
        # Row-major strides: the last dim steps by 1 and each earlier dim steps
        # over the product of all later extents; a 0-d shape gives ().
        strides = []
        step = 1
        for extent in reversed(shape):
            strides.append(step)
            step *= extent
        return tuple(reversed(strides))

    f32, i64 = torch.float32, torch.int64

    # Weight shapes of the 20 convolutions, in the order call() receives them.
    conv_weights = (
        (64, 3, 7, 7),
        (64, 64, 3, 3), (64, 64, 3, 3), (64, 64, 3, 3), (64, 64, 3, 3),
        (128, 64, 3, 3), (128, 128, 3, 3), (128, 64, 1, 1),
        (128, 128, 3, 3), (128, 128, 3, 3),
        (256, 128, 3, 3), (256, 256, 3, 3), (256, 128, 1, 1),
        (256, 256, 3, 3), (256, 256, 3, 3),
        (512, 256, 3, 3), (512, 512, 3, 3), (512, 256, 1, 1),
        (512, 512, 3, 3), (512, 512, 3, 3),
    )

    specs = []
    # Positions 0-59: each conv weight followed by two per-output-channel vectors.
    for weight in conv_weights:
        channels = (weight[0],)
        specs.extend([(weight, f32), (channels, f32), (channels, f32)])
    # Positions 60-61.
    specs.extend([((1000, 512), f32), ((1000,), f32)])
    # Positions 62-121: per conv, two per-output-channel vectors and an int64 scalar.
    for weight in conv_weights:
        channels = (weight[0],)
        specs.extend([(channels, f32), (channels, f32), ((), i64)])
    # Position 122.
    specs.append(((1, 3, 228, 228), f32))

    # Generated strictly in argument order so the random stream is consumed
    # exactly as one assignment per argument would consume it.
    inputs = [
        rand_strided(shape, _contiguous_strides(shape), device='cuda:0', dtype=dtype)
        for shape, dtype in specs
    ]

    def fn():
        # A new list object is built on every invocation, just like a per-call
        # list literal; presumably call() consumes the list it is given
        # (unverified from this view), so the same list must not be reused.
        return call(list(inputs))

    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # Running the generated module as a script hands the benchmark driver to
    # Inductor's command-line helper. The literal string 'None' is the
    # benchmark name emitted by the code generator and is forwarded verbatim.
    compiled_module_main('None', benchmark_compiled_module)
Advertisement
Add Comment
Please, Sign In to add comment