Not a member of Pastebin yet?
Sign Up —
it unlocks many cool features!
- { lambda ; a:f32[3,3] b:f32[3,3]. let
- c:f32[3,3] d:f32[3,3] = xla_call[
- call_jaxpr={ lambda ; e:f32[3,3] f:f32[3,3]. let
- g:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- h:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- i:i32[2] = concatenate[dimension=0] g h
- j:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] e i
- k:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- m:i32[2] = concatenate[dimension=0] k l
- n:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] e m
- o:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- p:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- q:i32[2] = concatenate[dimension=0] o p
- r:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] e q
- s:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- t:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- u:i32[2] = concatenate[dimension=0] s t
- v:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] e u
- w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- y:i32[2] = concatenate[dimension=0] w x
- z:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] f y
- ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- bc:i32[2] = concatenate[dimension=0] ba bb
- bd:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] f bc
- be:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- bf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
- bg:i32[2] = concatenate[dimension=0] be bf
- bh:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] f bg
- bi:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- bj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
- bk:i32[2] = concatenate[dimension=0] bi bj
- bl:f32[2,2] = gather[
- dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(), start_index_map=(0, 1))
- fill_value=None
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- slice_sizes=(2, 2)
- unique_indices=True
- ] f bk
- bm:f32[2,2] = add j 0.0
- bn:f32[2,2] = add bm n
- bo:f32[2,2] = add bn r
- bp:f32[2,2] = add bo v
- bq:f32[2,2] = add bp z
- br:f32[2,2] = add bq bd
- bs:f32[2,2] = add br bh
- bt:f32[2,2] = add bs bl
- _:f32[] = reduce_sum[axes=(0, 1)] bt
- bu:f32[2,2] = broadcast_in_dim[broadcast_dimensions=() shape=(2, 2)] 1.0
- bv:f32[2,2] = reduce_sum[axes=()] bu
- bw:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- bx:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] bw bk bv
- cb:f32[2,2] = reduce_sum[axes=()] bu
- cc:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- cd:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] cc bg cb
- ce:f32[3,3] = add_any bx cd
- cf:f32[2,2] = reduce_sum[axes=()] bu
- cg:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- ch:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] cg bc cf
- ci:f32[3,3] = add_any ce ch
- cj:f32[2,2] = reduce_sum[axes=()] bu
- ck:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- cl:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] ck y cj
- cm:f32[3,3] = add_any ci cl
- cn:f32[2,2] = reduce_sum[axes=()] bu
- co:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- cp:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] co u cn
- cq:f32[2,2] = reduce_sum[axes=()] bu
- cr:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- cs:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] cr q cq
- ct:f32[3,3] = add_any cp cs
- cu:f32[2,2] = reduce_sum[axes=()] bu
- cv:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- cw:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] cv m cu
- cx:f32[3,3] = add_any ct cw
- cy:f32[2,2] = reduce_sum[axes=()] bu
- cz:f32[3,3] = broadcast_in_dim[broadcast_dimensions=() shape=(3, 3)] 0.0
- da:f32[3,3] = scatter-add[
- dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(), scatter_dims_to_operand_dims=(0, 1))
- indices_are_sorted=True
- mode=GatherScatterMode.PROMISE_IN_BOUNDS
- unique_indices=True
- update_consts=()
- update_jaxpr={ lambda ; by:f32[] bz:f32[]. let
- ca:f32[] = add by bz
- in (ca,) }
- ] cz i cy
- db:f32[3,3] = add_any cx da
- in (db, cm) }
- name=total_energy
- ] a b
- in (c, d) }
Advertisement
Add Comment
Please sign in to add a comment.