Guest User

Untitled

a guest
Oct 21st, 2018
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 11.54 KB | None | 0 0
  1. {-# LANGUAGE ScopedTypeVariables #-}
  2. {-# LANGUAGE TypeOperators #-}
  3. {-# LANGUAGE FlexibleContexts #-}
  4. {-# LANGUAGE FlexibleInstances #-}
  5. {-# LANGUAGE CPP #-}
  6.  
  7. -- Naive translation of shen_anova1_batch.m into Accelerate.
  8.  
  9. module Main (main) where
  10.  
  11. import Data.Array.Accelerate as A
  12. import Data.Array.Accelerate.CUDA as R
  13. import Prelude as P
  14. import Data.List hiding (group)
  15. import Data.List.Split (splitEvery)
  16. -- import Data.Set as S
  17.  
  18. import Control.Monad (when, mapM_, forM_)
  19.  
  20. import Data.Array.Unboxed (IArray, UArray, elems, listArray)
  21. import System.Random
  22.  
  23. type Matrix a = Array DIM2 a
  24. type EI = Exp Int -- Shorthand for the lazy.
  25. runE x = R.run (unit x) -- Run for scalars
  26.  
  27. dbg = False -- TOGGLE me.
  28.  
  29. --------------------------------------------------------------------------------
  30.  
  31. -- Test harness -- set up and run the test.
  32. test_anova1 :: Int -> Int -> IO ()
  33. test_anova1 nsbj width = do
  34. let
  35. nsnp = 1
  36. nsbj' = constant nsbj
  37. area = constant (width*width)
  38. -- imdata = rand( width*width, nsbj )
  39. imdata :: Acc (Matrix Float) = generate (index2 area nsbj') inorder
  40. inorder x = let (r, c) = unlift (unindex2 x) in
  41. A.fromIntegral (r + c * area) -- Count by ones in row major..
  42.  
  43. -- Ammendment -- for now fixing nsnp == 1 and just doing a vector rather than a matrix for snpdata:
  44. snpdata :: Acc (Vector Int)
  45. snpdata = generate (index1 nsbj') ((`mod` 3) . unindex1) -- Return 0,1, or 2
  46.  
  47. putStrLn "=== Accelerate naive anova1 batch ==="
  48.  
  49. putStrLn$ "Image width: "++ show width
  50. putStrLn$ "SNP category membership across patients:"++ pp(R.run snpdata)
  51.  
  52. -- tic
  53. -- forM [0 .. nsnp-1] $ \ i -> do
  54. anova1_batch imdata snpdata
  55. -- fmap1(:,i) = F; ?????
  56. -- toc
  57. return ()
  58.  
  59.  
  60. ----------------------------------------------------------------------------------------------------
  61. -- Shen's implementation of anova1 in a batch mode
  62. -- 'voxels' is a matrix
  63. -- 'group' is a column vector
  64. -- 'voxel' columns have the same group memebership
  65. anova1_batch :: Acc (Matrix Float) -> Acc (Vector Int) -> IO ()
  66. -- RRN: I believe group should represent a set of individual's single
  67. -- nucleotide values at a specific location (a specific SNP).
  68. -- 'voxel's represents a /flat/ row of voxels for each individual in
  69. -- the group. Here we analyze how variance at each voxel correlates
  70. -- with the different categories of individuals created by 'group'.
  71. --
  72. -- Specifically, in this case there are three /subgroups/ within
  73. -- 'group', corresponding to snp=0, snp=1, snp=2.
  74. anova1_batch voxels group = do
  75. let d = shape voxels
  76. -- n = total number of voxels in each image
  77. -- r = number of parallel anova runs == number of individuals in group
  78. (n::EI, r::EI) = unlift$ unindex2 d
  79.  
  80. -- FIXME:
  81. -- uniqueSnps = unique (R.run group) -- 'g': All the unique SNP settings.
  82. -- numUniqueSNPs = size g -- 'p': How many unique values.
  83. uniqueSNPs = [0,1,2] -- TEMP, FIXME -- hardcoding
  84. numUniqueSNPs = 3 -- TEMP, FIXME -- hardcoding
  85. -- NOTE: We could try to keep everything within Accelerate, but
  86. -- for the "outermost" loop here, over different SNP values, we
  87. -- allow ourselves the rconvenience of using lists, and keeping
  88. -- things in regular Haskell.
  89.  
  90. -- First, how much variance is there across ALL voxels (per sample):
  91. var_v = variance1of2 voxels :: Acc (Vector Float)
  92. ss_T = A.map (df_T *) var_v;
  93. -- Next, we compute per-group variance.
  94. ------------------------------------------------------------
  95. -- Group is a mapping between index->groupID, and we effectively want to invert it
  96. -- to get groupID->indices. To accomplish that we first explicitly attach indices:
  97. zippedGroup = A.zip (iota r) group
  98.  
  99. -- For each observed SNP value {0,1,2} ...
  100. idxs = P.map (\ snpVal ->
  101. let
  102. -- Here we want the index of *all* of the voxels belonging to
  103. -- individuals with a particular snpVal.
  104. idx = A.map A.fst $ filterAcc isMatch zippedGroup
  105. isMatch pr = let (_::EI,thisSnp::EI) = unlift pr in
  106. thisSnp ==* constant snpVal
  107. in idx
  108. ) uniqueSNPs
  109. -- Now we're ready to extract the set of voxels that belong to this group of patients:
  110. groupedVoxels :: [Acc (Matrix Float)]
  111. -- rowGroups = selectRows idx voxels
  112. groupedVoxels = P.map (`selectRows` voxels) idxs
  113. -- Whew, now we have finally achieved unflat = X(idx,:) in the original code....
  114.  
  115. -- Next we look at variance within each voxel for individuals in THIS snp category:
  116. -- The result is a per-voxel variance:
  117. -- variances = A.map (* (A.fromIntegral$ newCols - 1)) (variance2of2 unflat)
  118. variances :: [Acc (Vector Float)] = P.map variance2of2 groupedVoxels
  119.  
  120. -- What is the size of each group, minus 1:
  121. lens :: [Exp Float] = P.map ((\x -> A.fromIntegral (x - 1)) . unindex1 . shape) idxs
  122.  
  123. -- Do a sum reduction, the result is an array of per-voxel statistics:
  124. ss_W = foldl1 (.+) (P.zipWith (liftScalar (.*)) lens variances)
  125. ------------------------------------------------------------
  126.  
  127. df_B :: Exp Float = A.fromIntegral$ numUniqueSNPs - 1
  128. df_W :: Exp Float = A.fromIntegral$ n - numUniqueSNPs
  129. df_T :: Exp Float = A.fromIntegral$ n - 1
  130.  
  131. -- Well, this is quite verbose due to there being no scalar/array overloading:
  132. ms_W = liftScalarR (./) ss_W df_W
  133. ss_B = ss_T .- ss_W;
  134. ms_B = liftScalarR (./) ss_B df_B
  135. f_statistic = ms_B ./ ms_W
  136.  
  137. putStrLn$ "Finally, here's the F statistic: " ++ pp(run ss_W)
  138. return ()
  139.  
  140.  
  141. -- TEMP -- hardcoding this to a specific dimensionality until I fix the general version below
  142. -- This sums along the RIGHT of two dimensions (INNERMOST).
  143. -- If interpreted as a (row,column) matrix, this collapses each row.
  144. variance2of2 :: (Elt a, IsFloating a)
  145. => Acc (Array DIM2 a) -> Acc (Array DIM1 a)
  146. variance2of2 arr =
  147. A.map (/ (denom-1))
  148. (A.zipWith (-) sum2
  149. (A.zipWith (*) mean sum1))
  150. where
  151. mean = A.map (/ denom) sum1
  152. denom = A.fromIntegral rows
  153. Z :. (rows::EI) :. (cols::EI) = unlift sh1
  154. sh1 :: Exp ( Z :. Int :. Int) = shape arr
  155. -- Sum along LEFT dimension:
  156. sum1 = fold (+) 0 arr
  157. sum2 = fold (+) 0 sqrs
  158. sqrs = A.zipWith (*) arr arr
  159.  
  160.  
  161. -- TEMP -- hardcoding this to a specific dimensionality until I fix the general version below
  162. -- This iterates along the LEFT of two dimensions (OUTERMOST).
  163. -- If interpreted as a (row,column) matrix, this collapses each column,
  164. -- which is the same behavior as the var() function in matlab applied to a matrix.
  165. variance1of2 :: (Elt a, IsFloating a)
  166. => Acc (Array DIM2 a) -> Acc (Array DIM1 a)
  167. variance1of2 arr =
  168. A.map (/ (denom-1))
  169. (A.zipWith (-) sum2
  170. (A.zipWith (*) mean sum1))
  171. where
  172. mean = A.map (/ denom) sum1
  173. denom = A.fromIntegral rows
  174. Z :. (rows::EI) :. (cols::EI) = unlift sh1
  175. sh1 :: Exp ( Z :. Int :. Int) = shape arr
  176. -- Sum along LEFT dimension:
  177. swapped = swap2Dims arr
  178. sum1 = fold (+) 0 swapped
  179. sum2 = fold (+) 0 sqrs
  180. sqrs = A.zipWith (*) swapped swapped
  181.  
  182.  
  183. -- A swap function fixed to exactly two dimensions.
  184. swap2Dims :: (Elt e, Num e) => Acc (Matrix e) -> Acc (Matrix e)
  185. swap2Dims a =
  186. backpermute new_extent swap2 a -- Perform a gather.
  187. where
  188. new_extent = swap2 (shape a)
  189. swap2 :: Exp DIM2 -> Exp DIM2
  190. swap2 ind = let (Z :.i2 :.i1) = unlift ind in
  191. lift (Z :. (i1::EI) :. (i2::EI))
  192.  
  193. ----------------------------------------------------------------------------------------------------
  194. -- Let's print matrices nicely.
  195.  
  196. padleft n str | length str >= n = str
  197. padleft n str | otherwise = P.take (n - length str) (repeat ' ') ++ str
  198.  
  199. class NiceShow a where
  200. pp :: a -> String
  201.  
  202. instance Show a => NiceShow (Array DIM1 a) where
  203. pp arr =
  204. capends$ concat$
  205. intersperse " " $
  206. P.map (padleft maxpad) ls
  207. where
  208. ls = P.map show $ toList arr
  209. maxpad = maximum$ P.map length ls
  210.  
  211. capends x = "| "++x++" |"
  212.  
  213. -- This could be much more efficient:
  214. instance Show a => NiceShow (Array DIM2 a) where
  215. pp arr = concat $
  216. intersperse "\n" $
  217. P.map (capends .
  218. concat .
  219. intersperse " " .
  220. P.map (padleft maxpad))
  221. rowls
  222. where (Z :. rows :. cols) = arrayShape arr
  223. ls = P.map show $ toList arr
  224. maxpad = maximum$ P.map length ls
  225. rowls = splitEvery cols ls
  226.  
  227. ----------------------------------------------------------------------------------------------------
  228. -- General Utility functions:
  229. ----------------------------------------------------------------------------------------------------
  230.  
  231. -- Project a subset of the columns (second dim) from a two dimensional matrix.
  232. --
  233. -- @selectRows sl xs@ is similar to xs(sl,:) in Matlab.
  234. selectRows :: Elt e => Acc (Vector Int) -> Acc (Array DIM2 e) -> Acc (Array DIM2 e)
  235. selectRows sl xs =
  236. let Z :. rows = unlift $ shape sl
  237. Z :. _ :. cols = unlift $ shape xs :: Z :. Exp Int :. Exp Int
  238. in
  239. backpermute
  240. (index2 rows cols)
  241. (\ix -> let Z :. j :. i = unlift ix in index2 (sl ! index1 j) i)
  242. xs
  243.  
  244. -- Filter -- from accelerate-examples
  245. -- ------
  246. filterAcc :: Elt a
  247. => (Exp a -> Exp Bool)
  248. -> Acc (Vector a)
  249. -> Acc (Vector a)
  250. filterAcc p arr
  251. = let -- arr = A.use vec
  252. flags = A.map (boolToInt . p) arr
  253. (targetIdx, len) = A.scanl' (+) 0 flags
  254. arr' = A.backpermute (index1 $ the len) id arr
  255. in
  256. A.permute const arr' (\ix -> flags!ix ==* 0 ? (ignore, index1 $ targetIdx!ix)) arr
  257. -- FIXME: This is abusing 'permute' in that the first two arguments are
  258. -- only justified because we know the permutation function will
  259. -- write to each location in the target exactly once.
  260. -- Instead, we should have a primitive that directly encodes the
  261. -- compaction pattern of the permutation function.
  262.  
  263. -- Create a vector containing integers in the range [0,n)
  264. -- iota :: (Elt n, IsNum n) => Exp Int -> Acc (Vector n)
  265. iota :: Exp Int -> Acc (Vector Int)
  266. iota n = A.generate (index1 n) (A.fromIntegral . unindex1)
  267.  
  268.  
  269. -- Perform a scalar/array operation by lifting an array/array binary
  270. -- operator.
  271. liftScalar :: (Shape sh, Elt e)
  272. => (Acc (Array sh e) -> Acc (Array sh e) -> Acc (Array sh e))
  273. -> Exp e -> Acc (Array sh e) -> Acc (Array sh e)
  274. liftScalar op left right = op left' right
  275. where
  276. -- I can't figure out how to use replicate here. This is a spurious data dependency introduced by map:
  277. left' = A.map (\_ -> left) right
  278. -- left' = A.replicate (shape right) (A.unit left)
  279.  
  280.  
  281. -- Same but flipped:
  282. liftScalarR :: (Shape sh, Elt e)
  283. => (Acc (Array sh e) -> Acc (Array sh e) -> Acc (Array sh e))
  284. -> Acc (Array sh e) -> Exp e -> Acc (Array sh e)
  285. liftScalarR op = flip (liftScalar (flip op))
  286.  
  287. --------------------------------------------------------------------------------
  288.  
  289. -- Elementwise operation matching matlab:
  290.  
  291. infixl 7 ./
  292. infixl 7 .*
  293. infixl 6 .+
  294. infixl 6 .-
  295.  
  296. a ./ b = A.zipWith (/) a b
  297. a .* b = A.zipWith (*) a b
  298. a .+ b = A.zipWith (+) a b
  299. a .- b = A.zipWith (-) a b
  300.  
  301. --------------------------------------------------------------------------------
  302.  
  303. small = test_anova1 6 3
  304. big = test_anova1 300 256
  305. main = small
  306.  
  307. ----------------------------------------------------------------------------------------------------
Add Comment
Please, Sign In to add comment