Advertisement
gatoatigrado3

haskell framework code for neural nets PA4 Q10

Nov 27th, 2012
248
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.07 KB | None | 0 0
  1. import Prelude hiding (id, (.))
  2. import Control.Arrow
  3. import Control.Applicative
  4. import Control.Category
  5. import Control.Monad
  6. import Control.Monad.Trans.Class
  7.  
  8. import System.Environment
  9.  
  10. import Text.CSV
  11. import Text.Printf
  12.  
  13.  
  14. pfcn_slow :: [[Double]] -> Double
  15. pfcn_slow wts = log $ sum [assn_value h v | h <- hiddens, v <- visibles] where
  16. -- value of an assignment to hidden and visible units
  17. assn_value :: [Double] -> [Double] -> Double
  18. assn_value hs vs = exp $ sum [h_i * wts_ij * v_j
  19. | (h_i, wts_i) <- zip hs wts
  20. , (v_j, wts_ij) <- zip vs wts_i]
  21.  
  22. -- all assignments to hidden and visible variables
  23. hiddens, visibles :: [[Double]]
  24. hiddens = assns (length wts)
  25. visibles = assns (length (head wts))
  26.  
  27. assns :: Int -> [[Double]]
  28. assns 0 = [[]]
  29. assns n = (map (0:) (assns $ n-1)) ++ (map (1:) (assns $ n-1))
  30.  
  31. -- given (log x_i), compute log (sum_i x_i) in a numerically stable manner
  32. logs_to_logsum :: [Double] -> Double
  33. logs_to_logsum logxs
  34. | log_max_xs <- maximum logxs
  35. = log_max_xs + log (sum [exp (logx - log_max_xs) | logx <- logxs])
  36.  
  37. prod_over :: [α] -> (α -> Double) -> Double
  38. prod_over xs = product . flip map xs
  39.  
  40. -- log (\prod_i f(i))
  41. -- which, of course, equals \sum_i log f(i), I just like the conceptual name
  42. log_prod_over :: [α] -> (α -> Double) -> Double
  43. log_prod_over xs = sum . flip map xs . (log .)
  44.  
  45. should_be_eq = (exp $ log_prod_over [1, 2, 3, 4] id, 24)
  46.  
  47. logsumover :: [α] -> (α -> Double) -> Double
  48. -- given some indices {i}, and a function yielding {log f(i)},
  49. -- compute log (sum_i f(i))
  50. logsumover indices i_to_logx = logs_to_logsum (map i_to_logx indices)
  51.  
  52. should_be_eq' = (exp $ logsumover [1, 2, 3] log, 6)
  53.  
  54. pfcn_fast :: [[Double]] -> Double
  55. pfcn_fast wts = YOUR IMPLEMENTATION HERE
  56.  
  57. main = go =<< parseCSVFromFile . head =<< getArgs where
  58. go (Left err) = print "couldn't parse CSV; usage: rbm_part_fcn csv_file"
  59. go (Right csv) = print $ pfcn_fast (map (map read) $ stripNewline csv)
  60. stripNewline csv
  61. | last csv == [""] = init csv
  62. | otherwise = csv
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement