Guest User

Untitled

a guest
Sep 15th, 2019
192
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. {-# LANGUAGE LambdaCase #-}
  2.  
  3. module Main (main) where
  4.  
  5. import System.Clock (Clock(ProcessCPUTime), getTime, toNanoSecs)
  6. import System.Environment (getArgs)
  7. import System.Exit (exitFailure)
  8. import Text.Printf (printf)
  9. import Data.Array.Unboxed
  10. import Data.Array.IO
  11. import Data.Array.Base
  12. import GHC.Arr (unsafeIndex)
  13.  
  14. -- Matrix multiplication benchmark
  15.  
-- | Two-dimensional unboxed array of 'Double's indexed by @(Int, Int)@ pairs.
-- The functions in this file treat the first index component as the row and
-- the second as the column.
type Matrix = UArray (Int, Int) Double
  17.  
  18. main :: IO ()
  19. main = do
  20.   [n] <- getArgs >>= \case
  21.     [a] -> pure ([read a :: Int])
  22.     _ -> exitFailure
  23.  
  24.   t1 <- clock
  25.   let a = newMatrix n
  26.       b = newMatrix n
  27.  
  28.   let c' = matrixMult a $ transpose b
  29.  
  30.  printf "% 8.6f\n" (c' ! ((n `div` 2), (n `div` 2)) )
  31.  
  32.   t2 <- clock
  33.   printf "%ds\n" $ (t2 - t1) `div` 1000000000
  34.  
  35. newMatrix :: Int -> Matrix
  36. newMatrix n =
  37.   let tmp = 1 / fromIntegral n / fromIntegral n :: Double in
  38.   array ((0, 0), (pred n, pred n))
  39.       [((i,j), tmp * fromIntegral(i - j) * fromIntegral (i + j))
  40.           | i <- range (0, pred n),
  41.             j <- range (0, pred n) ]
  42.  
  43.  
  44. transpose :: Matrix -> Matrix
  45. transpose x = array resultBounds [((j,i), x!(i,j))
  46.                                      | i <- range (li,ui),
  47.                                        j <- range (lj,uj) ]
  48.   where ((li,lj),(ui,uj))     =  bounds x
  49.         resultBounds          =  ((lj,li),(uj,ui))
  50.  
  51. matrixMult :: Matrix -> Matrix -> Matrix
  52. matrixMult x y    =  array resultBounds [((i,j),
  53.                                          let basei = rowIndex x i
  54.                                              basej = rowIndex y j
  55.                                          in sum [unsafeAt x ( basei + k ) * unsafeAt y ( basej + k )
  56.                                                     | k <- range (lj,uj) ]
  57.                                          )
  58.                                             | i <- range (li,ui),
  59.                                               j <- range (li',ui') ]
  60.                      
  61.   where ((li,lj),(ui,uj))         =  bounds x
  62.         ((li',lj'),(ui',uj'))     =  bounds y
  63.         resultBounds | (lj,uj)==(lj',uj')    =  ((li,li'),(ui,ui'))
  64.                      | otherwise             = error "matMult: incompatible bounds"
  65.         rowIndex  arr n           = index (bounds arr) (n,0)
  66.  
  67. clock :: IO Integer
  68. clock = toNanoSecs <$> getTime ProcessCPUTime
Add Comment
Please, Sign In to add comment