Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- module Lib where
- import Control.Monad
- import Control.Monad.Primitive
- import Data.Kind
- import Data.Proxy
- import GHC.TypeLits
- import qualified LinearAlgebra as LA
- import System.Random.MWC
- --
-- | Allocate a random bias column vector with @n@ rows.
--
-- Each of the @n@ entries is an independent draw from MWC's 'uniform', so the
-- range of values is whatever 'uniform' produces for the element type that
-- 'LA.matrix' expects. NOTE(review): that element type is defined in
-- "LinearAlgebra" and is not visible here.
makeBias :: forall n m . (KnownNat n, PrimMonad m) => Gen (PrimState m) -> m (LA.L n 1)
makeBias gen = fmap LA.matrix (replicateM rowCount (uniform gen))
  where
    -- Number of rows, reflected from the type-level @n@.
    rowCount = fromInteger (natVal (Proxy :: Proxy n))
- --
-- | Allocate a random @j × k@ weight matrix.
--
-- Draws @j * k@ independent samples with MWC's 'uniform' and hands the flat
-- list to 'LA.matrix'. NOTE(review): whether 'LA.matrix' consumes that list in
-- row- or column-major order is decided in "LinearAlgebra", not here.
makeWeight :: forall j k m . (KnownNat j, KnownNat k, PrimMonad m) => Gen (PrimState m) -> m (LA.L j k)
makeWeight gen = LA.matrix <$> replicateM entryCount (uniform gen)
  where
    entryCount = dimOf (Proxy :: Proxy j) * dimOf (Proxy :: Proxy k)

    -- Reflect a type-level dimension down to an 'Int' count.
    dimOf :: KnownNat d => Proxy d -> Int
    dimOf = fromInteger . natVal
- --
-- | Constraint asserting that every element of a type-level list is a
-- 'KnownNat'.
--
-- The class has no methods. Solving @KnownNats xs@ walks the list and demands a
-- 'KnownNat' instance for each element.
--
-- The parameter's kind is pinned to @[Nat]@. A methodless class otherwise gets
-- its kind only from this declaration: it would default to 'Type' without
-- PolyKinds, which rejects both instances below. Under PolyKinds it would be
-- kind-polymorphic, so lists of non-'Nat' kinds would spuriously satisfy it.
class KnownNats (xs :: [Nat])

-- | The empty list trivially satisfies the constraint.
instance KnownNats '[]

-- | A cons cell satisfies it when its head is known and its tail does too.
instance (KnownNat x, KnownNats xs) => KnownNats (x ': xs)
-- | Drop the first element of a type-level list.
--
-- Total on concrete lists: the tail of the empty list is defined to be the
-- empty list, so the family never gets stuck on @'[]@.
type family Tail (xs :: [a]) :: [a] where
  Tail '[]         = '[]
  Tail (_ ': rest) = rest
-- | The last element of a type-level list of naturals.
--
-- 'feedforward' uses it to compute the width of a network's final layer from
-- the list of layer widths. No equation matches @'[]@, so @Output '[]@ stays
-- stuck (irreducible). The singleton equation must come first, because it
-- overlaps with the recursive one.
type family Output (xs :: [Nat]) :: Nat where
  Output '[lastWidth]     = lastWidth
  Output (_ ': remaining) = Output remaining
-- | A feed-forward network indexed by the list of its layer widths, from
-- input to output.
--
--   * 'NNil' records only a single width @a@ and carries no parameters.
--   * 'NCons' prepends a layer that maps width @x@ to width @y@. It holds a
--     @y × 1@ bias and a @y × x@ weight matrix, followed by the remaining
--     network, which must start at width @y@.
data NetworkCons (xs :: [Nat]) where
  -- Terminal layer: no weights, just the width.
  NNil
    :: (KnownNat a)
    => NetworkCons '[a]
  -- One layer of parameters in front of the rest of the network.
  NCons
    :: (KnownNat x, KnownNat y)
    => LA.L y 1                    -- bias
    -> LA.L y x                    -- weights
    -> NetworkCons (y ': ys)       -- remaining layers
    -> NetworkCons (x ': y ': ys)
-- | Propagate an input column vector through every layer of the network.
--
-- At each 'NCons' layer the next vector is @weight <> input + bias@. When the
-- walk reaches 'NNil', the current vector is returned unchanged. The result
-- width is the last entry of the layer list (see 'Output').
--
-- NOTE(review): no activation function (e.g. 'sigmoid') is applied between
-- layers. As written, the network is a composition of affine maps; confirm
-- this is intended.
feedforward :: (KnownNat x) => LA.L x 1 -> NetworkCons (x:xs) -> LA.L (Output (x:xs)) 1
feedforward activation network =
  case network of
    NNil ->
      activation
    NCons bias weight rest ->
      -- Parenthesised product: the only dimensionally valid grouping.
      feedforward ((weight LA.<> activation) + bias) rest
- --
-- | The logistic function @1 / (1 + e^(-x))@.
--
-- Mathematically it maps the reals into (0, 1). In floating point it saturates
-- to exactly 0 or 1 for inputs of large magnitude.
sigmoid :: (Floating a) => a -> a
sigmoid = recip . (1 +) . exp . negate
Add Comment
Please, Sign In to add comment