Guest User

Untitled

a guest
Nov 23rd, 2017
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.39 KB | None | 0 0
  1. module Lib where
  2.  
  3. import Control.Monad
  4. import Control.Monad.Primitive
  5. import Data.Kind
  6. import Data.Proxy
  7. import GHC.TypeLits
  8. import qualified LinearAlgebra as LA
  9. import System.Random.MWC
  10.  
  11. --
  12.  
  13. makeBias :: forall n m . (KnownNat n, PrimMonad m) => Gen (PrimState m) -> m (LA.L n 1)
  14. makeBias gen = LA.matrix <$> replicateM (fromInteger $ natVal (Proxy :: Proxy n)) (uniform gen)
  15.  
  16. --
  17.  
  18. makeWeight :: forall j k m . (KnownNat j, KnownNat k, PrimMonad m) => Gen (PrimState m) -> m (LA.L j k)
  19. makeWeight gen = LA.matrix <$> replicateM (j * k) (uniform gen)
  20. where
  21. j = fromInteger $ natVal (Proxy :: Proxy j)
  22. k = fromInteger $ natVal (Proxy :: Proxy k)
  23.  
  24. --
  25.  
  26. class KnownNats xs
  27. instance KnownNats '[]
  28. instance (KnownNat x, KnownNats xs) => KnownNats (x:xs)
  29.  
  30. type family Tail (xs :: [a]) :: [a] where
  31. Tail '[] = '[]
  32. Tail (x:xs) = xs
  33.  
  34. type family Output (xs :: [Nat]) :: Nat where
  35. Output '[x] = x
  36. Output (x:xs) = Output xs
  37.  
  38. data NetworkCons (xs :: [Nat]) where
  39. NNil :: (KnownNat a) => NetworkCons '[a]
  40. NCons :: (KnownNat x, KnownNat y) => LA.L y 1 -> LA.L y x -> NetworkCons (y:ys) -> NetworkCons (x:y:ys)
  41.  
  42. feedforward :: (KnownNat x) => LA.L x 1 -> NetworkCons (x:xs) -> LA.L (Output (x:xs)) 1
  43. feedforward input (NCons bias weight net') = feedforward (weight LA.<> input + bias) net'
  44. feedforward input NNil = input
  45.  
  46. --
  47.  
  48. sigmoid :: (Floating a) => a -> a
  49. sigmoid x = recip $ 1 + exp (negate x)
Add Comment
Please, Sign In to add comment