Guest User

jit_jaxpr

a guest
Aug 1st, 2022
19
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.03 KB | None | 0 0
  1. { lambda ; a:f32[3,3] b:f32[3,3]. let
  2. c:f32[3,3] d:f32[3,3] = xla_call[
  3. call_jaxpr={ lambda ; e:f32[3,3] f:f32[3,3]. let
  4. g:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  5. h:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  6. i:i32[2] = concatenate[dimension=0] g h
  7. j:f32[2,2] = gather[
  8. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  9. fill_value=None
  10. indices_are_sorted=True
  11. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  12. slice_sizes=(2, 2)
  13. unique_indices=True
  14. ] e i
  15. k:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  16. l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  17. m:i32[2] = concatenate[dimension=0] k l
  18. n:f32[2,2] = gather[
  19. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  20. fill_value=None
  21. indices_are_sorted=True
  22. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  23. slice_sizes=(2, 2)
  24. unique_indices=True
  25. ] e m
  26. o:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  27. p:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  28. q:i32[2] = concatenate[dimension=0] o p
  29. r:f32[2,2] = gather[
  30. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  31. fill_value=None
  32. indices_are_sorted=True
  33. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  34. slice_sizes=(2, 2)
  35. unique_indices=True
  36. ] e q
  37. s:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  38. t:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  39. u:i32[2] = concatenate[dimension=0] s t
  40. v:f32[2,2] = gather[
  41. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  42. fill_value=None
  43. indices_are_sorted=True
  44. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  45. slice_sizes=(2, 2)
  46. unique_indices=True
  47. ] e u
  48. w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  49. x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  50. y:i32[2] = concatenate[dimension=0] w x
  51. z:f32[2,2] = gather[
  52. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  53. fill_value=None
  54. indices_are_sorted=True
  55. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  56. slice_sizes=(2, 2)
  57. unique_indices=True
  58. ] f y
  59. ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  60. bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  61. bc:i32[2] = concatenate[dimension=0] ba bb
  62. bd:f32[2,2] = gather[
  63. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  64. fill_value=None
  65. indices_are_sorted=True
  66. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  67. slice_sizes=(2, 2)
  68. unique_indices=True
  69. ] f bc
  70. be:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  71. bf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
  72. bg:i32[2] = concatenate[dimension=0] be bf
  73. bh:f32[2,2] = gather[
  74. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  75. fill_value=None
  76. indices_are_sorted=True
  77. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  78. slice_sizes=(2, 2)
  79. unique_indices=True
  80. ] f bg
  81. bi:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  82. bj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
  83. bk:i32[2] = concatenate[dimension=0] bi bj
  84. bl:f32[2,2] = gather[
  85. dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
  86. fill_value=None
  87. indices_are_sorted=True
  88. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  89. slice_sizes=(2, 2)
  90. unique_indices=True
  91. ] f bk
  92. bm:f32[2,2] = add j 0.0
  93. bn:f32[2,2] = add bm n
  94. bo:f32[2,2] = add bn r
  95. bp:f32[2,2] = add bo v
  96. bq:f32[2,2] = add bp z
  97. br:f32[2,2] = add bq bd
  98. bs:f32[2,2] = add br bh
  99. bt:f32[2,2] = add bs bl
  100. _:f32[] = reduce_sum[axes=(0, 1)] bt
  101. bu:f32[2,2] = broadcast_in_dim[broadcast_dimensions=() shape=(2, 2)] 1.0
  102. bv:f32[2,2] = reduce_sum[axes=()] bu
  103. bw:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  104. bx:f32[3,3] = scatter-add[
  105. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  106. indices_are_sorted=True
  107. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  108. unique_indices=True
  109. update_consts=()
  110. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  111. ca:f32[] = add by bz
  112. in (ca,) }
  113. ] bw bk bv
  114. cb:f32[2,2] = reduce_sum[axes=()] bu
  115. cc:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  116. cd:f32[3,3] = scatter-add[
  117. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  118. indices_are_sorted=True
  119. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  120. unique_indices=True
  121. update_consts=()
  122. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  123. ca:f32[] = add by bz
  124. in (ca,) }
  125. ] cc bg cb
  126. ce:f32[3,3] = add_any bx cd
  127. cf:f32[2,2] = reduce_sum[axes=()] bu
  128. cg:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  129. ch:f32[3,3] = scatter-add[
  130. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  131. indices_are_sorted=True
  132. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  133. unique_indices=True
  134. update_consts=()
  135. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  136. ca:f32[] = add by bz
  137. in (ca,) }
  138. ] cg bc cf
  139. ci:f32[3,3] = add_any ce ch
  140. cj:f32[2,2] = reduce_sum[axes=()] bu
  141. ck:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  142. cl:f32[3,3] = scatter-add[
  143. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  144. indices_are_sorted=True
  145. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  146. unique_indices=True
  147. update_consts=()
  148. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  149. ca:f32[] = add by bz
  150. in (ca,) }
  151. ] ck y cj
  152. cm:f32[3,3] = add_any ci cl
  153. cn:f32[2,2] = reduce_sum[axes=()] bu
  154. co:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  155. cp:f32[3,3] = scatter-add[
  156. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  157. indices_are_sorted=True
  158. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  159. unique_indices=True
  160. update_consts=()
  161. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  162. ca:f32[] = add by bz
  163. in (ca,) }
  164. ] co u cn
  165. cq:f32[2,2] = reduce_sum[axes=()] bu
  166. cr:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  167. cs:f32[3,3] = scatter-add[
  168. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  169. indices_are_sorted=True
  170. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  171. unique_indices=True
  172. update_consts=()
  173. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  174. ca:f32[] = add by bz
  175. in (ca,) }
  176. ] cr q cq
  177. ct:f32[3,3] = add_any cp cs
  178. cu:f32[2,2] = reduce_sum[axes=()] bu
  179. cv:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  180. cw:f32[3,3] = scatter-add[
  181. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  182. indices_are_sorted=True
  183. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  184. unique_indices=True
  185. update_consts=()
  186. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  187. ca:f32[] = add by bz
  188. in (ca,) }
  189. ] cv m cu
  190. cx:f32[3,3] = add_any ct cw
  191. cy:f32[2,2] = reduce_sum[axes=()] bu
  192. cz:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
  193. da:f32[3,3] = scatter-add[
  194. dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
  195. indices_are_sorted=True
  196. mode=GatherScatterMode.PROMISE_IN_BOUNDS
  197. unique_indices=True
  198. update_consts=()
  199. update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
  200. ca:f32[] = add by bz
  201. in (ca,) }
  202. ] cz i cy
  203. db:f32[3,3] = add_any cx da
  204. in (db, cm) }
  205. name=total_energy
  206. ] a b
  207. in (c, d) }
Advertisement
Add Comment
Please, Sign In to add comment