Advertisement
Guest User

Untitled

a guest
Oct 21st, 2019
184
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.71 KB | None | 0 0
  1. import random
  2. import math
  3.  
  4. def get_arms(prior, n_arms):
  5.     beta_arms = [list(prior) for k in range(n_arms)]
  6.     return beta_arms
  7.  
  8.  
  9. def select_arm_with_thompson_sampling_beta_dist(arms):
  10.     samples = [random.gauss(*arm) for arm in arms]
  11.     best_arm = max(enumerate(samples), key=lambda x:x[1])
  12.     return best_arm[0]
  13.  
  14.  
  15. def play_arm(arm_i):
  16.     # interact with the random environment
  17.     # in this case we assume that the best arm has a reward prob of 80%
  18.     # and the rest arms has a 40% win prob
  19.     best_arm_hidden_from_user = 3
  20.     if arm_i == best_arm_hidden_from_user:
  21.         return random.randint(15, 30)
  22.  
  23.     else:
  24.         return random.randint(0, 15)
  25.  
  26. def update_arms_with_conjugate_prior(arm, reward):
  27.     alpha = arm[0]
  28.     beta = arm[1]
  29.  
  30.     if reward:
  31.         return alpha + 1, beta
  32.  
  33.     else:
  34.         return alpha, beta + 1
  35.  
  36. def update_arms_with_gaussian(arm, observation):
  37.     mean0 = arm[0]
  38.     var0 = arm[1]
  39.  
  40.     mean = (1 / ((1/var0) + (1/5))) * ((mean0/var0) + (observation/5))
  41.     var = 1 / ((1/var0) + 1/5)
  42.  
  43.     return mean, var
  44.  
  45. T = 100
  46. n_arms = 5
  47.  
  48. #       mean  var
  49. prior = (15, 5)
  50.  
  51. arms = get_arms(prior, n_arms)
  52. rewards = 0
  53.  
  54. for t in range(T):
  55.     best_arm = select_arm_with_thompson_sampling_beta_dist(arms)
  56.     reward_t = play_arm(best_arm)
  57.     rewards += reward_t
  58.  
  59.     #updated_arm_dist = update_arms_with_conjugate_prior(arms[best_arm], reward_t)
  60.     updated_arm_dist = update_arms_with_gaussian(arms[best_arm], reward_t)
  61.     arms[best_arm] = updated_arm_dist
  62.  
  63.  
  64. print("-----------------")
  65. print("total reward: {}".format(rewards))
  66. print("----- arms ------")
  67. for k, arm in enumerate(arms):
  68.     print("arm {} => dist: ({}, {}) ".format(k, arm[0], arm[1]))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement