Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import Prelude hiding (id, (.))
- import Control.Arrow
- import Control.Applicative
- import Control.Category
- import Control.Monad
- import Control.Monad.Trans.Class
- import System.Environment
- import Text.CSV
- import Text.Printf
-- | Reference (brute-force) log partition function of a binary RBM.
--
-- Row @i@ of @wts@ holds the weights of hidden unit @i@; entry @j@ of that
-- row couples it to visible unit @j@. Every joint 0/1 assignment @(h, v)@ is
-- enumerated and the result is
-- @log (sum_{h,v} exp (sum_{i,j} h_i * W_ij * v_j))@.
--
-- The work grows as @2^(hidden + visible)@, so this is only usable for tiny
-- models. The number of visible units is taken from the first row, so an
-- empty matrix is an error. The sum of exponentials is formed directly, so
-- very large energies overflow to @Infinity@.
pfcn_slow :: [[Double]] -> Double
pfcn_slow wts =
  log (sum (concatMap (\hs -> map (weightOf hs) visibleStates) hiddenStates))
  where
    hiddenStates, visibleStates :: [[Double]]
    hiddenStates = binaryVectors (length wts)
    visibleStates = binaryVectors (length (head wts))

    -- exp of the bilinear energy term h^T W v for one joint assignment
    weightOf :: [Double] -> [Double] -> Double
    weightOf hs vs =
      exp . sum . concat $
        zipWith (\h row -> zipWith (\v w -> h * w * v) vs row) hs wts

    -- all 0/1 vectors of length n, first coordinate varying slowest
    binaryVectors :: Int -> [[Double]]
    binaryVectors n = replicateM n [0, 1]
-- | Given the logarithms @log x_i@ of some non-negative numbers, compute
-- @log (sum_i x_i)@ without overflow (the log-sum-exp trick). The largest
-- log is factored out, so every argument passed to 'exp' is @<= 0@.
--
-- Edge cases:
--
-- * An empty input is an empty sum, so the result is @log 0 = -Infinity@
--   (instead of crashing in 'maximum').
-- * If the largest log is infinite, it is returned directly. This covers
--   all terms being zero (@-Infinity@) and some term being infinite
--   (@+Infinity@). In both cases the shift would otherwise compute
--   @Infinity - Infinity = NaN@.
logs_to_logsum :: [Double] -> Double
logs_to_logsum [] = negate (1 / 0)
logs_to_logsum logxs
  | isInfinite log_max_xs = log_max_xs
  | otherwise =
      log_max_xs + log (sum [exp (logx - log_max_xs) | logx <- logxs])
  where
    log_max_xs = maximum logxs
-- | Product of @f x@ over every @x@ in the list (1 for an empty list).
prod_over :: [a] -> (a -> Double) -> Double
prod_over items f = product [f item | item <- items]
-- | @log (prod_i f(i))@ over the given indices.
--
-- This is the same number as @sum_i log f(i)@, and that is how it is
-- computed: the logs are summed and the product itself is never formed.
-- The name just reads better at call sites that think in products.
log_prod_over :: [a] -> (a -> Double) -> Double
log_prod_over items f = sum [log (f item) | item <- items]
-- | Sanity check for 'log_prod_over': a (computed, expected) pair that should
-- agree up to floating-point rounding, since exp (log 1 + ... + log 4) = 24.
-- The explicit signature keeps the expected value a 'Double' instead of
-- letting it default to 'Integer', so the two halves can be compared.
should_be_eq :: (Double, Double)
should_be_eq = (exp (log_prod_over [1, 2, 3, 4] id), 24)
-- | @log (sum_i f(i))@ over the given indices.
--
-- The supplied function returns @log f(i)@ rather than @f(i)@; the logs are
-- collected in index order and combined by 'logs_to_logsum'.
logsumover :: [a] -> (a -> Double) -> Double
logsumover idxs logOf = logs_to_logsum [logOf i | i <- idxs]
-- | Sanity check for 'logsumover': a (computed, expected) pair that should
-- agree up to floating-point rounding, since 1 + 2 + 3 = 6. The explicit
-- signature keeps the expected value a 'Double' instead of letting it
-- default to 'Integer'.
should_be_eq' :: (Double, Double)
should_be_eq' = (exp (logsumover [1, 2, 3] log), 6)
-- | Log partition function @log Z@ of a binary RBM with weight matrix @wts@.
-- Row @i@ is hidden unit @i@ and column @j@ is visible unit @j@. This is the
-- same quantity as 'pfcn_slow', but joint assignments are not enumerated.
--
-- The energy @h^T W v@ is bilinear. Once one layer is fixed, the units of
-- the other layer are independent and can be summed out in closed form:
--
-- > Z = sum_h prod_j (1 + exp (sum_i h_i W_ij))
-- >   = sum_v prod_i (1 + exp (sum_j W_ij v_j))
--
-- Whichever layer is smaller is enumerated, so the cost is
-- O(2^min(m,n) * m * n) rather than O(2^(m+n) * m * n). Everything is kept
-- in log space (stable softplus plus log-sum-exp), so large weights do not
-- overflow.
--
-- Assumes a rectangular matrix; the visible count is taken from the first
-- row. An empty matrix is treated as a model with no units, giving
-- @log 1 = 0@.
pfcn_fast :: [[Double]] -> Double
pfcn_fast wts
  | nHidden <= nVisible =
      logSumExp [sum (map softplus (visibleFields h)) | h <- states nHidden]
  | otherwise =
      logSumExp [sum (map softplus (hiddenFields v)) | v <- states nVisible]
  where
    nHidden, nVisible :: Int
    nHidden = length wts
    nVisible = case wts of
      [] -> 0
      (row : _) -> length row

    -- input to each visible unit for a fixed hidden assignment: sum_i h_i W_ij
    visibleFields :: [Double] -> [Double]
    visibleFields h =
      foldr (zipWith (+)) (replicate nVisible 0)
        [map (hi *) row | (hi, row) <- zip h wts]

    -- input to each hidden unit for a fixed visible assignment: sum_j W_ij v_j
    hiddenFields :: [Double] -> [Double]
    hiddenFields v = [sum (zipWith (*) row v) | row <- wts]

    -- log (1 + exp x), arranged so exp never sees a positive argument
    softplus :: Double -> Double
    softplus x = max x 0 + log (1 + exp (negate (abs x)))

    -- every 0/1 vector of length k (never empty: k = 0 gives [[]])
    states :: Int -> [[Double]]
    states k = replicateM k [0, 1]

    -- log (sum (map exp xs)), shifted by the maximum for stability;
    -- only ever called on the non-empty lists built above
    logSumExp :: [Double] -> Double
    logSumExp xs =
      let m = maximum xs
       in m + log (sum [exp (x - m) | x <- xs])
-- | Command-line entry point: @rbm_part_fcn csv_file@.
--
-- Reads a weight matrix from the CSV file (one record per row of the matrix
-- passed to 'pfcn_fast') and prints its log partition function. Arguments
-- after the first are ignored. With no argument, the usage line is printed
-- instead of crashing on an empty argument list.
main :: IO ()
main = do
  args <- getArgs
  case args of
    (path : _) -> parseCSVFromFile path >>= go
    [] -> putStrLn usage
  where
    usage = "usage: rbm_part_fcn csv_file"

    -- putStrLn rather than print, so the message is not shown quoted;
    -- the parser's own error is reported as well.
    go (Left err) = do
      putStrLn ("couldn't parse CSV; " ++ usage)
      print err
    go (Right csv) = print $ pfcn_fast (map (map read) $ dropBlankRecords csv)

    -- A trailing newline shows up as a record holding a single empty field.
    -- Such records carry no weights, so drop them before 'read' sees them.
    -- filter is total, unlike last/init.
    dropBlankRecords = filter (/= [""])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement