Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
-- | Example: gather the elements of 'vec' selected by the segment
-- descriptors in 'info'. The @(start, length)@ descriptors are
-- @(0,1)@, @(10,7)@, @(5,0)@ and @(5,2)@, so the intended result is the
-- elements of 'vec' at indices 0, 10..16, 5 and 6 (the empty segment
-- contributes nothing).
test :: Acc (Vector Float)
test = gather (segmentedIndices info) vec

-- | Length of each segment.
segs :: Segments Int
segs = fromList (Z :. 4) [1, 7, 0, 2]

-- | Start index of each segment into 'vec'.
starts :: Vector Int
starts = fromList (Z :. 4) [0, 10, 5, 5]

-- | @(start, length)@ descriptor for each segment.
info :: Acc (Vector (Int,Int))
info = A.zip (use starts) (use segs)

-- | Source vector @[0, 1 .. 19]@ (only the first 20 list elements are used).
vec :: Acc (Vector Float)
vec = use $ fromList (Z :. 20) [0 ..]
- --
-- | Expand a vector of @(start, length)@ segment descriptors into the
-- concatenation of the index ranges @[start .. start + length - 1]@.
--
-- For example @[(0,1), (10,7), (5,0), (5,2)]@ becomes
-- @[0, 10,11,12,13,14,15,16, 5,6]@; zero-length segments contribute nothing.
--
-- One @(flag, value)@ pair is built per output position. The first position
-- of each non-empty segment holds @(1, start)@ and every other position holds
-- @(0, 1)@. A segmented inclusive sum over these pairs then yields the indices.
-- The head flag is kept separate from the start value so that a segment
-- starting at index 0 is still recognised as a segment head.
segmentedIndices
  :: Acc (Vector (Int,Int))
  -> Acc (Vector Int)
segmentedIndices segInfo
  = A.map A.snd
  $ A.scanl1 (segmented (+)) heads
  where
    (start, seg)  = A.unzip segInfo
    -- exclusive prefix sum of the lengths: the output position where each
    -- segment begins, plus the total output length
    (offset, len) = unlift (scanl' (+) 0 seg)
    -- non-head positions: no flag, step of 1 from the previous index
    steps         = fill (index1 (the len)) (lift (0 :: Exp Int, 1 :: Exp Int))
    -- head positions: flag set, value is the segment's start index
    marks         = A.map (\s -> lift (1 :: Exp Int, s)) start
    -- empty segments are skipped: their offset equals the next segment's
    -- offset (or the total length, for trailing empty segments)
    heads         = A.permute const steps
                      (\ix -> seg ! ix A.== 0 ? (ignore, index1 (offset ! ix)))
                      marks

-- | Scatter each segment's start index to the position where that segment
-- begins in the concatenated output; all other positions are 0.
--
-- The result has one element per output position (the sum of the lengths).
-- A segment that starts at index 0 is indistinguishable from a non-head
-- position here, so this result is not usable as a head-flag array by itself.
mkHeadIndices
  :: Acc (Vector (Int,Int))
  -> Acc (Vector Int)
mkHeadIndices segInfo
  = A.permute const zeros toHead start
  where
    (start, seg)  = A.unzip segInfo
    (offset, len) = unlift (scanl' (+) 0 seg)
    zeros         = fill (index1 (the len)) 0
    -- skip empty segments: they would alias the next segment's offset
    toHead ix     = seg ! ix A.== 0 ? (ignore, index1 (offset ! ix))

-- | Lift a binary operator to a segmented one over @(flag, value)@ pairs, for
-- use with a scan. A non-zero flag on the right-hand operand marks the start
-- of a new segment, so accumulation restarts from that operand's value.
--
-- The result is associative whenever @f@ is, which a parallel scan requires.
-- On a segment head the value must be @bV@, not the flag @bF@; otherwise
-- combining two heads would yield the bitwise OR of their flags.
segmented
  :: (Exp Int -> Exp Int -> Exp Int)
  -> Exp (Int, Int)
  -> Exp (Int, Int)
  -> Exp (Int, Int)
segmented f a b =
  let (aF, aV) = unlift a :: (Exp Int, Exp Int)
      (bF, bV) = unlift b :: (Exp Int, Exp Int)
  in
  lift ( aF A..|. bF
       , bF A./= 0 ? (bV, f aV bV) )
Add Comment
Please, Sign In to add comment