Advertisement
Guest User

Untitled

a guest
Sep 22nd, 2019
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.76 KB | None | 0 0
  1. {-# LANGUAGE DerivingVia #-}
  2.  
  3. {-|
  4. Module : Bandit.Exp3
  5. Copyright : (c) 2019, UChicago Argonne, LLC.
  6. License : MIT
  7. Maintainer : fre@freux.fr
  8. -}
  9. module Bandit.Exp3
  10. ( Exp3 (..)
  11. , Weight (..)
  12. )
  13. where
  14.  
  15. import Bandit.Class
  16. import Control.Lens
  17. import qualified Data.Aeson as A
  18. import Data.Data
  19. import Data.Generics.Product
  20. import Data.JSON.Schema
  21. import Data.MessagePack
  22. import Data.Random
  23. import qualified Data.Random.Distribution.Categorical as DC
  24. import qualified Data.Random.Sample as RS
  25. import Dhall hiding (field)
  26. import NRM.Classes.Messaging
  27. import Protolude
  28. import Refined
  29.  
  30. data Exp3 a
  31. = Exp3
  32. { t :: Int
  33. , lastAction :: Maybe a
  34. , k :: Int
  35. , ws :: [Weight a]
  36. }
  37. deriving (Generic)
  38.  
-- | Sampling probability attached to one arm of an 'Exp3' bandit.
--
-- 'init' stores 1 for every arm, so the values are only normalised (divided
-- by the sum over arms) once 'recompute' has run. 'pickAction' passes them
-- to 'DC.fromWeightedList' when sampling. JSON encoding and schema go
-- through 'GenericJSON'.
newtype Probability = Probability {getProbability :: Double}
  deriving (JSONSchema, A.ToJSON, A.FromJSON) via GenericJSON Probability
  deriving (Show, Generic, Data, MessagePack, Interpret, Inject)
  42.  
-- | Running total of the loss estimates charged to one arm.
--
-- It starts at 0 in 'init'. Each time the arm is the one that was played,
-- 'updateCumLoss' adds @loss / probability@ (the loss divided by the
-- probability the arm had when it was sampled). JSON encoding and schema go
-- through 'GenericJSON'.
newtype CumulativeLoss = CumulativeLoss {getCumulativeLoss :: Double}
  deriving (JSONSchema, A.ToJSON, A.FromJSON) via GenericJSON CumulativeLoss
  deriving (Show, Generic, Data, MessagePack, Interpret, Inject)
  46.  
  47. data Weight a
  48. = Weight
  49. { probability :: Probability
  50. , cumulativeLoss :: CumulativeLoss
  51. , action :: a
  52. }
  53. deriving (Generic)
  54.  
  55. instance (Eq a) => Bandit (Exp3 a) Set a (Refined (FromTo 0 1) Double) where
  56.  
  57. init as = Exp3
  58. { t = 1
  59. , lastAction = Nothing
  60. , k = length as
  61. , ws = toList as <&> Weight (Probability 1) (CumulativeLoss 0)
  62. }
  63.  
  64. step (unrefine -> l) =
  65. get <&> lastAction >>= \case
  66. Nothing -> pickAction
  67. Just oldAction -> do
  68. field @"ws" %=
  69. fmap (\w -> if action w == oldAction then updateCumLoss l w else w)
  70. t <- use $ field @"t"
  71. k <- use $ field @"k"
  72. field @"ws" %= recompute t k
  73. field @"t" += 1
  74. pickAction
  75.  
  76. pickAction :: (MonadRandom m, MonadState (Exp3 a) m) => m a
  77. pickAction = get >>= s >>= btw (assign (field @"lastAction") . Just)
  78. where
  79. s bandit = RS.sample . DC.fromWeightedList $ ws bandit <&> w2tuple
  80. w2tuple (Weight p _ action) = (getProbability p, action)
  81.  
  82. updateCumLoss :: Double -> Weight a -> Weight a
  83. updateCumLoss l w@(Weight (Probability p) (CumulativeLoss cL) _) =
  84. w & field @"cumulativeLoss" .~ CumulativeLoss (cL + (l / p))
  85.  
  86. recompute :: Int -> Int -> [Weight a] -> [Weight a]
  87. recompute t k ws = updatep <$> ws
  88. where
  89. updatep w@(Weight _ (CumulativeLoss cL) _) =
  90. w & field @"probability" . field @"getProbability" .~
  91. expw cL /
  92. denom
  93. expw cL =
  94. exp (- sqrt (2.0 * log (fromIntegral k) / fromIntegral (t * k)) * cL)
  95. denom = getSum $ foldMap denomF ws
  96. denomF (getCumulativeLoss . cumulativeLoss -> cL) = Sum $ expw cL
  97.  
  98. btw :: (Functor f) => (t -> f b) -> t -> f t
  99. btw k x = x <$ k x
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement